mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
大改
This commit is contained in:
247
vllm_npu/multistream/ms_split.py
Normal file
247
vllm_npu/multistream/ms_split.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm_npu.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
from .base import MSAttentionMetadataSplitConfig
|
||||
|
||||
|
||||
def compute_split_seq_index(
|
||||
query_lens: Optional[list[int]],
|
||||
attn_state: AscendAttentionState,
|
||||
num_tokens: int,
|
||||
imbalance_ratio: float = 0.1,
|
||||
) -> list[int]:
|
||||
if attn_state != AscendAttentionState.DecodeOnly:
|
||||
assert query_lens is not None
|
||||
total_tokens = sum(query_lens)
|
||||
# the first index in last split
|
||||
tokens, split_index = 0, 0
|
||||
for value in query_lens:
|
||||
tokens += value
|
||||
split_index += 1
|
||||
if tokens >= total_tokens // 2:
|
||||
# check the current split index
|
||||
if abs(tokens -
|
||||
total_tokens // 2) < total_tokens * imbalance_ratio:
|
||||
return [tokens, split_index]
|
||||
# check the previous split index
|
||||
elif abs(tokens - total_tokens // 2 -
|
||||
value) < total_tokens * imbalance_ratio:
|
||||
return [tokens - value, split_index - 1]
|
||||
# fail to split if it is imbalanced
|
||||
# TODO: split tokens in seq
|
||||
else:
|
||||
return [0, 0]
|
||||
else:
|
||||
tokens = num_tokens // 2
|
||||
return [tokens, tokens]
|
||||
return [0, 0]
|
||||
|
||||
|
||||
def split_attn_tensor_type(
|
||||
input_tensor: torch.Tensor,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [input_tensor[:index], input_tensor[index:]]
|
||||
|
||||
|
||||
def split_attn_int_type(
|
||||
var: int,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [min(var, index), max(var - index, 0)]
|
||||
|
||||
|
||||
def model_input_split_v1_mla_attn(
|
||||
attn_metadata,
|
||||
_metadata_cls,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> List[Any]:
|
||||
assert 0 < ms_split_config.num_micro_batches < 3
|
||||
if attn_metadata is None:
|
||||
return [attn_metadata]
|
||||
[token_index,
|
||||
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
|
||||
attn_metadata.attn_state,
|
||||
attn_metadata.num_decode_tokens)
|
||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||
attn_metadata.query_lens):
|
||||
return [attn_metadata]
|
||||
|
||||
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
|
||||
dtype=int)
|
||||
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
|
||||
if attn_metadata.num_prefills > 0:
|
||||
prefill_query_start_loc = np.zeros(
|
||||
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
|
||||
np.cumsum(attn_metadata.prefill.query_lens,
|
||||
out=prefill_query_start_loc[1:])
|
||||
|
||||
# split attn metadata
|
||||
[slot_mapping_pre,
|
||||
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
|
||||
token_index)
|
||||
[num_decodes_pre,
|
||||
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
|
||||
seq_index)
|
||||
[num_decode_tokens_pre, num_decode_tokens_post
|
||||
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
|
||||
[num_prefills_pre, num_prefills_post
|
||||
] = split_attn_int_type(attn_metadata.num_prefills,
|
||||
max(0, seq_index - attn_metadata.num_decodes))
|
||||
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
|
||||
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
|
||||
query_start_loc_pre = query_start_loc_post = None
|
||||
if attn_metadata.query_start_loc is not None:
|
||||
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
||||
query_start_loc_post = deepcopy(
|
||||
attn_metadata.query_start_loc[seq_index:]
|
||||
) - attn_metadata.query_start_loc[seq_index]
|
||||
[block_table_pre,
|
||||
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
|
||||
seq_index)
|
||||
assert attn_metadata.attn_mask is not None
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
# the attn_mla kernel in torch npu only accept 128*128 attn mask
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = attn_metadata.attn_state
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
# should be none in decode only state
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
# chunked prefill
|
||||
if num_prefills_pre > 0:
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
|
||||
seq_lens_pre)].contiguous()
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
else:
|
||||
attn_state_pre = AscendAttentionState.DecodeOnly
|
||||
attn_mask_pre = None
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
from vllm_npu.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||
AscendMLAPrefillMetadata)
|
||||
if num_prefills_pre > 0:
|
||||
# split metadata.prefill
|
||||
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
|
||||
attn_metadata.prefill.input_positions,
|
||||
token_index - attn_metadata.num_decode_tokens)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
[prefill_query_lens_pre, prefill_query_lens_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
|
||||
seq_index
|
||||
+
|
||||
1 -
|
||||
attn_metadata
|
||||
.
|
||||
num_decodes]
|
||||
prefill_query_start_loc_post = deepcopy(
|
||||
attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes:]
|
||||
) - attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes]
|
||||
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
|
||||
context_len_post = seq_lens_post
|
||||
prefill_max_query_len_pre = max(prefill_query_lens_pre)
|
||||
prefill_max_query_len_post = max(prefill_query_lens_post)
|
||||
prefill_pre = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_pre,
|
||||
query_lens=prefill_query_lens_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=prefill_query_start_loc_pre,
|
||||
input_positions=input_positions_pre,
|
||||
context_lens=context_len_pre,
|
||||
block_table=block_tables_pre,
|
||||
max_query_len=prefill_max_query_len_pre,
|
||||
max_seq_lens=context_len_pre.max().item(),
|
||||
)
|
||||
prefill_post = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_post,
|
||||
query_lens=prefill_query_lens_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=prefill_query_start_loc_post,
|
||||
input_positions=input_positions_post,
|
||||
context_lens=context_len_post,
|
||||
block_table=block_tables_post,
|
||||
max_query_len=prefill_max_query_len_post,
|
||||
max_seq_lens=context_len_post.max().item(),
|
||||
)
|
||||
decode_pre = attn_metadata.decode
|
||||
decode_post = None
|
||||
else:
|
||||
# prefill is None, split metadata.decode
|
||||
[input_positions_pre, input_positions_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
|
||||
token_index)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.block_table,
|
||||
seq_index)
|
||||
[decode_seq_lens_pre,
|
||||
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
decode_pre = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_pre,
|
||||
block_table=block_tables_pre,
|
||||
seq_lens=decode_seq_lens_pre,
|
||||
max_seq_lens=max(decode_seq_lens_pre),
|
||||
seq_lens_list=decode_seq_lens_pre.tolist(),
|
||||
)
|
||||
decode_post = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_post,
|
||||
block_table=block_tables_post,
|
||||
seq_lens=decode_seq_lens_post,
|
||||
max_seq_lens=max(decode_seq_lens_post),
|
||||
seq_lens_list=decode_seq_lens_post.tolist(),
|
||||
)
|
||||
prefill_pre = None
|
||||
prefill_post = attn_metadata.prefill
|
||||
# construct metadata
|
||||
from vllm_npu.attention.mla_v1 import AscendMLAPrefillMetadata
|
||||
attention_metadata_pre = _metadata_cls(
|
||||
num_actual_tokens=token_index,
|
||||
num_input_tokens=token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=query_start_loc_pre,
|
||||
block_tables=block_table_pre,
|
||||
num_decodes=num_decodes_pre,
|
||||
num_prefills=num_prefills_pre,
|
||||
num_decode_tokens=num_decode_tokens_pre,
|
||||
attn_state=attn_state_pre,
|
||||
attn_mask=attn_mask_pre,
|
||||
prefill=prefill_pre,
|
||||
decode=decode_pre,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
attention_metadata_post = _metadata_cls(
|
||||
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
||||
num_input_tokens=attn_metadata.num_input_tokens - token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=query_start_loc_post,
|
||||
block_tables=block_table_post,
|
||||
num_decodes=num_decodes_post,
|
||||
num_prefills=num_prefills_post,
|
||||
num_decode_tokens=num_decode_tokens_post,
|
||||
attn_mask=attn_mask_post,
|
||||
attn_state=attn_state_post,
|
||||
prefill=prefill_post,
|
||||
decode=decode_post,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
return [attention_metadata_pre, attention_metadata_post]
|
||||
Reference in New Issue
Block a user