This commit is contained in:
2026-02-10 23:08:39 +08:00
parent 1baa36026c
commit 6680585975
172 changed files with 52867 additions and 892 deletions

View File

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from enum import Enum
class MSEventKey(Enum):
ATTN_COM_FINISH = 0
ATTN_AR_FINISH = 1
FFN_COM_FINISH = 2
FFN_AR_FINISH = 3
# events for MOE dispatch and combine
MOE_BEFORE_COMM = 4
MOE_AFTER_COMM = 5
# events for shared expert
MOE_SE_COMM_FINISH = 6
MOE_SE_COMP_FINISH = 7
MOE_GATE_FINISH = 8
@dataclass
class MSAttentionMetadataSplitConfig:
"""
micro batch split config for split attention metadata
"""
# micro batch num
num_micro_batches: int = 2
# split micro batches only when total tokens >= min_total_tokens_to_split
min_total_tokens_to_split: int = 256
# split micro batches only when prefill tokens >= min_prefill_tokens_to_split
min_prefill_tokens_to_split: int = 64

View File

@@ -0,0 +1,67 @@
from contextlib import contextmanager
from typing import Any
_ms_comm_context: Any = None
_cur_micro_batch_num: int = -1
_ms_layer_index_context: int = -1
_ms_metadata_context: Any = None
_ms_attn_metadata_context: Any = None
def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
attn_metadata: Any):
"""
set multistream layer context before transformer layers
"""
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
_ms_layer_index_context = start_layer
_ms_metadata_context = ms_metadata
_ms_attn_metadata_context = attn_metadata
def reset_multistream_layer_context():
"""
reset multistream layer context
"""
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
_ms_layer_index_context = -1
_ms_metadata_context = None
_ms_attn_metadata_context = None
def get_multistream_layer_context():
"""
get multistream layer context
"""
return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
def advance_step_multistream_layer_context():
"""
advance multistream layer index context
"""
global _ms_layer_index_context
_ms_layer_index_context += 1
def get_multistream_comm_context() -> Any:
"""Get the current comm forward context."""
return _ms_comm_context
def get_multistream_microbatch_context() -> int:
return _cur_micro_batch_num
@contextmanager
def set_multistream_context(context: Any, micro_batch_num: int):
"""A context manager that stores the current comm forward context,
can be attention metadata, etc."""
global _ms_comm_context, _cur_micro_batch_num
_ms_comm_context = context
_cur_micro_batch_num = micro_batch_num
try:
yield
finally:
_ms_comm_context = None
_cur_micro_batch_num = -1

View File

@@ -0,0 +1,22 @@
from .context import (get_multistream_layer_context,
get_multistream_microbatch_context)
# vllm v1 use get_forward_context to get the attn_metadata,
# we can use this decorator to update the attn metadata
def set_multistream_support():
def decorator(func):
def wrapper():
context = func()
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
)
micro_batch_num = get_multistream_microbatch_context()
if layer_index != -1 and micro_batch_num != -1:
context.attn_metadata = attn_metadata[micro_batch_num]
return context
return wrapper
return decorator

View File

@@ -0,0 +1,61 @@
from typing import List, Optional, Tuple, Union
import torch
from vllm.forward_context import get_forward_context
from .base import MSEventKey
from .context import (get_multistream_layer_context,
reset_multistream_layer_context,
set_multistream_layer_context)
from .metadata import MultiStreamMetadata
class MultiStreamPreTransformerLayer(torch.nn.Module):
def __init__(self, multistream_metadata: MultiStreamMetadata):
super().__init__()
self.multistream_metadata = multistream_metadata
def forward(
self,
intput_tensors: List[torch.Tensor],
):
attn_metadata = get_forward_context().attn_metadata
if self.multistream_metadata is None or attn_metadata is None:
set_multistream_layer_context(-1, None, None)
return attn_metadata, intput_tensors
# TODO add attn_metadata management
do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
attn_metadata, intput_tensors)
if do_ms:
set_multistream_layer_context(
self.multistream_metadata.start_layer,
self.multistream_metadata, attn_metadata)
else:
set_multistream_layer_context(-1, None, None)
return attn_metadata, intput_tensors
class MultiStreamPostTransformerLayer(torch.nn.Module):
def __init__(self, multistream_metadata: MultiStreamMetadata):
super().__init__()
self.multistream_metadata = multistream_metadata
def forward(self,
input_tensors: Union[List[Tuple[torch.Tensor]],
List[torch.Tensor],
List[List[torch.Tensor]]],
wait_layer_index: Optional[int] = None):
if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
return input_tensors
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
)
if layer_index >= 0:
true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
self.multistream_metadata.try_wait_event(
true_wait_layer,
self.multistream_metadata.ms_config.num_micro_batches - 1,
MSEventKey.FFN_AR_FINISH)
reset_multistream_layer_context()
return self.multistream_metadata.merge_micro_batches(input_tensors)

View File

@@ -0,0 +1,182 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
from vllm.sequence import IntermediateTensors
from vllm_npu.attention.mla_v1 import AscendMLAMetadata
from .base import MSAttentionMetadataSplitConfig, MSEventKey
def split_micro_batches_tensors(input_tensors,
split_index: int,
keys: Optional[List[str]] = None):
if isinstance(input_tensors, list):
micro_batches = []
for tensor in input_tensors:
if tensor is None:
micro_batches.append([None, None])
else:
micro_batches.append(
[tensor[:split_index], tensor[split_index:]])
return micro_batches
elif isinstance(input_tensors, torch.Tensor):
return [input_tensors[:split_index], input_tensors[split_index:]]
elif input_tensors is None:
return [None, None]
elif isinstance(input_tensors, Dict):
assert keys is not None
micro_batches_pre = {}
for key in keys:
micro_batches_pre[key] = input_tensors[key][:split_index]
micro_batches_post = {}
for key in keys:
micro_batches_post[key] = input_tensors[key][split_index:]
return [micro_batches_pre, micro_batches_post]
else:
raise NotImplementedError
@dataclass
class MultiStreamStepMetadata:
comm_stream: torch.npu.Stream = None
before_comm_event: torch.npu.Event = None
after_comm_event: torch.npu.Event = None
@dataclass
class MultiStreamConfig:
"""Controls the behavior of multi-stream models."""
min_total_tokens_to_split: int = 256
min_prefill_tokens_to_split: int = 64
num_micro_batches: int = 2
imbalance_ratio: float = 0.1
class MultiStreamMetadata:
# direct stream
calculate_stream = None
# delay stream
communicate_stream = None
# events
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
# multi-stream-flag
enable_multi_stream: bool = False
def __init__(
self,
calculate_stream: torch.npu.Stream,
communicate_stream: torch.npu.Stream,
start_layer: int,
end_layer: int,
event_keys: List[MSEventKey],
multistream_config: Optional[MultiStreamConfig],
causal_lm: bool = True,
):
self.calculate_stream = calculate_stream
self.communicate_stream = communicate_stream
self.start_layer = start_layer
self.end_layer = end_layer
self.ms_config = multistream_config
self.causal_lm = causal_lm
self._build_events(event_keys)
self._build_ms_split_config()
def _build_events(self, event_keys):
if self.ms_config is not None:
for i in range(self.start_layer - 1, self.end_layer):
self.ms_events[i] = {}
for j in range(self.ms_config.num_micro_batches):
self.ms_events[i][j] = {}
for key in event_keys:
self.ms_events[i][j][key] = torch.npu.Event()
def _build_ms_split_config(self):
if self.ms_config is not None:
self.ms_split_config = MSAttentionMetadataSplitConfig(
num_micro_batches=self.ms_config.num_micro_batches,
min_total_tokens_to_split=self.ms_config.
min_total_tokens_to_split,
min_prefill_tokens_to_split=self.ms_config.
min_prefill_tokens_to_split,
)
def try_wait_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].wait()
def try_record_event(self, layer_index: int, micro_batch_index: int,
event_key: MSEventKey):
self.ms_events[layer_index][micro_batch_index][event_key].record()
def split_micro_batch(
self,
attn_metadata: "AscendMLAMetadata",
intput_tensors: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors_keys: Optional[List[str]] = None,
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
List[torch.Tensor], List[List[torch.Tensor]]], Union[
IntermediateTensors, List[IntermediateTensors]]]:
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
self.ms_split_config)
if len(attn_metadata_list) == 1:
return False, attn_metadata_list[
0], intput_tensors, intermediate_tensors
split_index = attn_metadata_list[0].slot_mapping.shape[0]
input_tensors = split_micro_batches_tensors(intput_tensors,
split_index)
if intermediate_tensors is not None:
inter_tensors_list = split_micro_batches_tensors(
intermediate_tensors.tensors, split_index,
intermediate_tensors_keys)
intermediate_tensors = [
IntermediateTensors(inter_tensors)
for inter_tensors in inter_tensors_list
]
return True, attn_metadata_list, input_tensors, intermediate_tensors
def merge_micro_batches(
self, input_tensors: Union[List[torch.Tensor],
List[List[torch.Tensor]]]
) -> List[torch.Tensor]:
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
return input_tensors
batch: List[Optional[torch.Tensor]] = []
for tensors in input_tensors:
if tensors is None or tensors[0] is None:
batch.append(None)
else:
batch.append(torch.cat(tensors, dim=0))
return batch
def make_multistream_metadata_ds(
start_layer: int,
end_layer: int,
causal_lm: bool = True,
multistream_config: Optional[MultiStreamConfig] = None,
):
if multistream_config is None:
return None
event_keylist = [
MSEventKey.ATTN_COM_FINISH,
MSEventKey.ATTN_AR_FINISH,
MSEventKey.FFN_COM_FINISH,
MSEventKey.FFN_AR_FINISH,
MSEventKey.MOE_BEFORE_COMM,
MSEventKey.MOE_AFTER_COMM,
MSEventKey.MOE_SE_COMM_FINISH,
MSEventKey.MOE_SE_COMP_FINISH,
MSEventKey.MOE_GATE_FINISH,
]
return MultiStreamMetadata(
calculate_stream=torch.npu.current_stream(),
communicate_stream=torch.npu.Stream(),
start_layer=start_layer,
end_layer=end_layer,
multistream_config=multistream_config,
event_keys=event_keylist,
causal_lm=causal_lm,
)

View File

@@ -0,0 +1,247 @@
from copy import deepcopy
from typing import Any, List, Optional
import numpy as np
import torch
from vllm_npu.attention.attention_v1 import AscendAttentionState
from .base import MSAttentionMetadataSplitConfig
def compute_split_seq_index(
query_lens: Optional[list[int]],
attn_state: AscendAttentionState,
num_tokens: int,
imbalance_ratio: float = 0.1,
) -> list[int]:
if attn_state != AscendAttentionState.DecodeOnly:
assert query_lens is not None
total_tokens = sum(query_lens)
# the first index in last split
tokens, split_index = 0, 0
for value in query_lens:
tokens += value
split_index += 1
if tokens >= total_tokens // 2:
# check the current split index
if abs(tokens -
total_tokens // 2) < total_tokens * imbalance_ratio:
return [tokens, split_index]
# check the previous split index
elif abs(tokens - total_tokens // 2 -
value) < total_tokens * imbalance_ratio:
return [tokens - value, split_index - 1]
# fail to split if it is imbalanced
# TODO: split tokens in seq
else:
return [0, 0]
else:
tokens = num_tokens // 2
return [tokens, tokens]
return [0, 0]
def split_attn_tensor_type(
input_tensor: torch.Tensor,
index: int,
) -> List[torch.Tensor]:
return [input_tensor[:index], input_tensor[index:]]
def split_attn_int_type(
var: int,
index: int,
) -> List[torch.Tensor]:
return [min(var, index), max(var - index, 0)]
def model_input_split_v1_mla_attn(
attn_metadata,
_metadata_cls,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[Any]:
assert 0 < ms_split_config.num_micro_batches < 3
if attn_metadata is None:
return [attn_metadata]
[token_index,
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
attn_metadata.attn_state,
attn_metadata.num_decode_tokens)
if token_index == 0 or seq_index == 0 or seq_index == len(
attn_metadata.query_lens):
return [attn_metadata]
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
dtype=int)
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
if attn_metadata.num_prefills > 0:
prefill_query_start_loc = np.zeros(
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
np.cumsum(attn_metadata.prefill.query_lens,
out=prefill_query_start_loc[1:])
# split attn metadata
[slot_mapping_pre,
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
token_index)
[num_decodes_pre,
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
seq_index)
[num_decode_tokens_pre, num_decode_tokens_post
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
[num_prefills_pre, num_prefills_post
] = split_attn_int_type(attn_metadata.num_prefills,
max(0, seq_index - attn_metadata.num_decodes))
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
query_start_loc_pre = query_start_loc_post = None
if attn_metadata.query_start_loc is not None:
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
query_start_loc_post = deepcopy(
attn_metadata.query_start_loc[seq_index:]
) - attn_metadata.query_start_loc[seq_index]
[block_table_pre,
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
seq_index)
assert attn_metadata.attn_mask is not None
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
# the attn_mla kernel in torch npu only accept 128*128 attn mask
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = attn_metadata.attn_state
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
# should be none in decode only state
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
else:
# chunked prefill
if num_prefills_pre > 0:
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
seq_lens_pre)].contiguous()
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
else:
attn_state_pre = AscendAttentionState.DecodeOnly
attn_mask_pre = None
attn_state_post = AscendAttentionState.ChunkedPrefill
attn_mask_post = attn_metadata.attn_mask[
token_index:, :max(seq_lens_post)].contiguous()
from vllm_npu.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAPrefillMetadata)
if num_prefills_pre > 0:
# split metadata.prefill
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
attn_metadata.prefill.input_positions,
token_index - attn_metadata.num_decode_tokens)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
seq_index - attn_metadata.num_decodes)
[prefill_query_lens_pre, prefill_query_lens_post
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
seq_index - attn_metadata.num_decodes)
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[:
seq_index
+
1 -
attn_metadata
.
num_decodes]
prefill_query_start_loc_post = deepcopy(
attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes:]
) - attn_metadata.prefill.query_start_loc[seq_index -
attn_metadata.num_decodes]
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
context_len_post = seq_lens_post
prefill_max_query_len_pre = max(prefill_query_lens_pre)
prefill_max_query_len_post = max(prefill_query_lens_post)
prefill_pre = AscendMLAPrefillMetadata(
attn_mask=attn_mask_pre,
query_lens=prefill_query_lens_pre,
seq_lens=seq_lens_pre,
query_start_loc=prefill_query_start_loc_pre,
input_positions=input_positions_pre,
context_lens=context_len_pre,
block_table=block_tables_pre,
max_query_len=prefill_max_query_len_pre,
max_seq_lens=context_len_pre.max().item(),
)
prefill_post = AscendMLAPrefillMetadata(
attn_mask=attn_mask_post,
query_lens=prefill_query_lens_post,
seq_lens=seq_lens_post,
query_start_loc=prefill_query_start_loc_post,
input_positions=input_positions_post,
context_lens=context_len_post,
block_table=block_tables_post,
max_query_len=prefill_max_query_len_post,
max_seq_lens=context_len_post.max().item(),
)
decode_pre = attn_metadata.decode
decode_post = None
else:
# prefill is None, split metadata.decode
[input_positions_pre, input_positions_post
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
token_index)
[block_tables_pre, block_tables_post
] = split_attn_tensor_type(attn_metadata.decode.block_table,
seq_index)
[decode_seq_lens_pre,
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
decode_pre = AscendMLADecodeMetadata(
input_positions=input_positions_pre,
block_table=block_tables_pre,
seq_lens=decode_seq_lens_pre,
max_seq_lens=max(decode_seq_lens_pre),
seq_lens_list=decode_seq_lens_pre.tolist(),
)
decode_post = AscendMLADecodeMetadata(
input_positions=input_positions_post,
block_table=block_tables_post,
seq_lens=decode_seq_lens_post,
max_seq_lens=max(decode_seq_lens_post),
seq_lens_list=decode_seq_lens_post.tolist(),
)
prefill_pre = None
prefill_post = attn_metadata.prefill
# construct metadata
from vllm_npu.attention.mla_v1 import AscendMLAPrefillMetadata
attention_metadata_pre = _metadata_cls(
num_actual_tokens=token_index,
num_input_tokens=token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_pre,
seq_lens=seq_lens_pre,
query_start_loc=query_start_loc_pre,
block_tables=block_table_pre,
num_decodes=num_decodes_pre,
num_prefills=num_prefills_pre,
num_decode_tokens=num_decode_tokens_pre,
attn_state=attn_state_pre,
attn_mask=attn_mask_pre,
prefill=prefill_pre,
decode=decode_pre,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
attention_metadata_post = _metadata_cls(
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
num_input_tokens=attn_metadata.num_input_tokens - token_index,
head_dim=attn_metadata.head_dim,
slot_mapping=slot_mapping_post,
seq_lens=seq_lens_post,
query_start_loc=query_start_loc_post,
block_tables=block_table_post,
num_decodes=num_decodes_post,
num_prefills=num_prefills_post,
num_decode_tokens=num_decode_tokens_post,
attn_mask=attn_mask_post,
attn_state=attn_state_post,
prefill=prefill_post,
decode=decode_post,
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
)
return [attention_metadata_pre, attention_metadata_post]