from copy import deepcopy
from typing import Any, List, Optional

import numpy as np
import torch

from vllm_npu.attention.attention_v1 import AscendAttentionState

from .base import MSAttentionMetadataSplitConfig


def compute_split_seq_index(
    query_lens: Optional[list[int]],
    attn_state: AscendAttentionState,
    num_tokens: int,
    imbalance_ratio: float = 0.1,
) -> list[int]:
    """Find a [token_index, seq_index] pair that splits the batch into two
    roughly equal halves; returns [0, 0] if no balanced split exists."""
    if attn_state != AscendAttentionState.DecodeOnly:
        assert query_lens is not None
        total_tokens = sum(query_lens)
        # accumulate whole sequences until we pass the midpoint of all tokens
        tokens, split_index = 0, 0
        for value in query_lens:
            tokens += value
            split_index += 1
            if tokens >= total_tokens // 2:
                # check the current split index
                if abs(tokens -
                       total_tokens // 2) < total_tokens * imbalance_ratio:
                    return [tokens, split_index]
                # check the previous split index
                elif abs(tokens - total_tokens // 2 -
                         value) < total_tokens * imbalance_ratio:
                    return [tokens - value, split_index - 1]
                # fail to split if the result would be imbalanced
                # TODO: split tokens within a sequence
                else:
                    return [0, 0]
    else:
        # decode-only: every sequence contributes one token, so an even
        # token split is also an even sequence split
        tokens = num_tokens // 2
        return [tokens, tokens]
    return [0, 0]
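
# A worked illustration of the helper above (hypothetical values, not from
# this file): with query_lens=[3, 4, 5], total_tokens is 12 and the midpoint
# is 6. The loop reaches tokens=7 at split_index=2, and
# abs(7 - 6) = 1 < 12 * 0.1 = 1.2, so the function returns [7, 2]: the first
# micro-batch gets 7 tokens over 2 sequences, the second 5 tokens over 1.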


def split_attn_tensor_type(
    input_tensor: torch.Tensor,
    index: int,
) -> List[torch.Tensor]:
    return [input_tensor[:index], input_tensor[index:]]


def split_attn_int_type(
    var: int,
    index: int,
) -> List[int]:
    # clamp both halves so neither count goes negative when var < index
    return [min(var, index), max(var - index, 0)]
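
# Minimal illustrations of the two split helpers (hypothetical values):
#   split_attn_tensor_type(torch.arange(5), 2) -> [tensor([0, 1]), tensor([2, 3, 4])]
#   split_attn_int_type(10, 4) -> [4, 6]
#   split_attn_int_type(3, 5)  -> [3, 0]  # clamped: nothing left for the post half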


def model_input_split_v1_mla_attn(
    attn_metadata,
    _metadata_cls,
    ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[Any]:
    """Split MLA attention metadata into two micro-batches for multistream
    execution, or return the input unchanged if no balanced split exists."""
    assert 0 < ms_split_config.num_micro_batches < 3
    if attn_metadata is None:
        return [attn_metadata]
    [token_index,
     seq_index] = compute_split_seq_index(attn_metadata.query_lens,
                                          attn_metadata.attn_state,
                                          attn_metadata.num_decode_tokens)
    if token_index == 0 or seq_index == 0 or seq_index == len(
            attn_metadata.query_lens):
        return [attn_metadata]

    query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
                                   dtype=int)
    np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
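    # Illustration (hypothetical): query_lens [2, 3, 4] gives
    # query_start_loc_cpu [0, 2, 5, 9]; entry i is the offset of the first
    # token of sequence i in the flattened batch.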
    if attn_metadata.num_prefills > 0:
        prefill_query_start_loc = np.zeros(
            shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
        np.cumsum(attn_metadata.prefill.query_lens,
                  out=prefill_query_start_loc[1:])

    # split attn metadata
    [slot_mapping_pre,
     slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
                                                 token_index)
    [num_decodes_pre,
     num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
                                             seq_index)
    [num_decode_tokens_pre, num_decode_tokens_post
     ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
    [num_prefills_pre, num_prefills_post
     ] = split_attn_int_type(attn_metadata.num_prefills,
                             max(0, seq_index - attn_metadata.num_decodes))
    seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
    [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)

    query_start_loc_pre = query_start_loc_post = None
    if attn_metadata.query_start_loc is not None:
        query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
        query_start_loc_post = deepcopy(
            attn_metadata.query_start_loc[seq_index:]
        ) - attn_metadata.query_start_loc[seq_index]
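        # Illustration (hypothetical): query_start_loc [0, 2, 5, 9] with
        # seq_index=2 gives pre [0, 2, 5] and post [5, 9] - 5 = [0, 4], i.e.
        # the post half is rebased to start at offset 0 of its micro-batch.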
    [block_table_pre,
     block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
                                                seq_index)
    assert attn_metadata.attn_mask is not None
    if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
        # the attn_mla kernel in torch-npu only accepts a 128 x 128 attn mask,
        # so both halves share the full mask
        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
        attn_state_pre = attn_state_post = attn_metadata.attn_state
    elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
        # the mask should be None in the decode-only state
        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
    else:
        # chunked prefill
        if num_prefills_pre > 0:
            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
                seq_lens_pre)].contiguous()
            attn_mask_post = attn_metadata.attn_mask[
                token_index:, :max(seq_lens_post)].contiguous()
        else:
            attn_state_pre = AscendAttentionState.DecodeOnly
            attn_mask_pre = None
            attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_post = attn_metadata.attn_mask[
                token_index:, :max(seq_lens_post)].contiguous()
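
    # In the chunked-prefill path above, the shared mask is sliced row-wise at
    # token_index and column-wise at the longest remaining sequence, so each
    # micro-batch carries only the mask region it actually attends over.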
    from vllm_npu.attention.mla_v1 import (AscendMLADecodeMetadata,
                                           AscendMLAPrefillMetadata)
    if num_prefills_pre > 0:
        # split metadata.prefill
        [input_positions_pre, input_positions_post] = split_attn_tensor_type(
            attn_metadata.prefill.input_positions,
            token_index - attn_metadata.num_decode_tokens)
        [block_tables_pre, block_tables_post
         ] = split_attn_tensor_type(attn_metadata.prefill.block_table,
                                    seq_index - attn_metadata.num_decodes)
        [prefill_query_lens_pre, prefill_query_lens_post
         ] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
                                    seq_index - attn_metadata.num_decodes)
        prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[
            :seq_index + 1 - attn_metadata.num_decodes]
        prefill_query_start_loc_post = deepcopy(
            attn_metadata.prefill.query_start_loc[seq_index -
                                                  attn_metadata.num_decodes:]
        ) - attn_metadata.prefill.query_start_loc[seq_index -
                                                  attn_metadata.num_decodes]
        context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
        context_len_post = seq_lens_post
        prefill_max_query_len_pre = max(prefill_query_lens_pre)
        prefill_max_query_len_post = max(prefill_query_lens_post)
        prefill_pre = AscendMLAPrefillMetadata(
            attn_mask=attn_mask_pre,
            query_lens=prefill_query_lens_pre,
            seq_lens=seq_lens_pre,
            query_start_loc=prefill_query_start_loc_pre,
            input_positions=input_positions_pre,
            context_lens=context_len_pre,
            block_table=block_tables_pre,
            max_query_len=prefill_max_query_len_pre,
            max_seq_lens=context_len_pre.max().item(),
        )
        prefill_post = AscendMLAPrefillMetadata(
            attn_mask=attn_mask_post,
            query_lens=prefill_query_lens_post,
            seq_lens=seq_lens_post,
            query_start_loc=prefill_query_start_loc_post,
            input_positions=input_positions_post,
            context_lens=context_len_post,
            block_table=block_tables_post,
            max_query_len=prefill_max_query_len_post,
            max_seq_lens=context_len_post.max().item(),
        )
        decode_pre = attn_metadata.decode
        decode_post = None
    else:
        # prefill is None, so split metadata.decode instead
        [input_positions_pre, input_positions_post
         ] = split_attn_tensor_type(attn_metadata.decode.input_positions,
                                    token_index)
        [block_tables_pre, block_tables_post
         ] = split_attn_tensor_type(attn_metadata.decode.block_table,
                                    seq_index)
        [decode_seq_lens_pre,
         decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
        decode_pre = AscendMLADecodeMetadata(
            input_positions=input_positions_pre,
            block_table=block_tables_pre,
            seq_lens=decode_seq_lens_pre,
            max_seq_lens=max(decode_seq_lens_pre),
            seq_lens_list=decode_seq_lens_pre.tolist(),
        )
        decode_post = AscendMLADecodeMetadata(
            input_positions=input_positions_post,
            block_table=block_tables_post,
            seq_lens=decode_seq_lens_post,
            max_seq_lens=max(decode_seq_lens_post),
            seq_lens_list=decode_seq_lens_post.tolist(),
        )
        prefill_pre = None
        prefill_post = attn_metadata.prefill

    # construct metadata
    attention_metadata_pre = _metadata_cls(
        num_actual_tokens=token_index,
        num_input_tokens=token_index,
        head_dim=attn_metadata.head_dim,
        slot_mapping=slot_mapping_pre,
        seq_lens=seq_lens_pre,
        query_start_loc=query_start_loc_pre,
        block_tables=block_table_pre,
        num_decodes=num_decodes_pre,
        num_prefills=num_prefills_pre,
        num_decode_tokens=num_decode_tokens_pre,
        attn_state=attn_state_pre,
        attn_mask=attn_mask_pre,
        prefill=prefill_pre,
        decode=decode_pre,
        enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
    )
    attention_metadata_post = _metadata_cls(
        num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
        num_input_tokens=attn_metadata.num_input_tokens - token_index,
        head_dim=attn_metadata.head_dim,
        slot_mapping=slot_mapping_post,
        seq_lens=seq_lens_post,
        query_start_loc=query_start_loc_post,
        block_tables=block_table_post,
        num_decodes=num_decodes_post,
        num_prefills=num_prefills_post,
        num_decode_tokens=num_decode_tokens_post,
        attn_mask=attn_mask_post,
        attn_state=attn_state_post,
        prefill=prefill_post,
        decode=decode_post,
        enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
    )
    return [attention_metadata_pre, attention_metadata_post]
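
# Hedged usage sketch (assumptions, not confirmed by this file: the concrete
# metadata class is `AscendMLAMetadata` from vllm_npu.attention.mla_v1, and
# MSAttentionMetadataSplitConfig accepts `num_micro_batches` as a keyword):
#
#     from vllm_npu.attention.mla_v1 import AscendMLAMetadata
#
#     splits = model_input_split_v1_mla_attn(
#         attn_metadata=attn_metadata,
#         _metadata_cls=AscendMLAMetadata,
#         ms_split_config=MSAttentionMetadataSplitConfig(num_micro_batches=2),
#     )
#     # len(splits) == 1 when no balanced split point was found, else 2.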