# Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20)
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch

from vllm.sequence import IntermediateTensors

from vllm_npu.attention.mla_v1 import AscendMLAMetadata

from .base import MSAttentionMetadataSplitConfig, MSEventKey


def split_micro_batches_tensors(input_tensors,
                                split_index: int,
                                keys: Optional[List[str]] = None):
    """Split tensors into a [pre, post] micro-batch pair at split_index.

    Handles a list of tensors, a single tensor, None, and a keyed dict of
    tensors (the dict case requires `keys`).
    """
    if isinstance(input_tensors, list):
        micro_batches = []
        for tensor in input_tensors:
            if tensor is None:
                micro_batches.append([None, None])
            else:
                micro_batches.append(
                    [tensor[:split_index], tensor[split_index:]])
        return micro_batches
    elif isinstance(input_tensors, torch.Tensor):
        return [input_tensors[:split_index], input_tensors[split_index:]]
    elif input_tensors is None:
        return [None, None]
    elif isinstance(input_tensors, dict):
        assert keys is not None
        micro_batches_pre = {}
        for key in keys:
            micro_batches_pre[key] = input_tensors[key][:split_index]
        micro_batches_post = {}
        for key in keys:
            micro_batches_post[key] = input_tensors[key][split_index:]
        return [micro_batches_pre, micro_batches_post]
    else:
        raise NotImplementedError
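
# Illustrative usage sketch (tensor names are hypothetical): splitting a flat
# batch of 6 token rows into micro-batches of 4 and 2 along dim 0.
#
#     hidden_states = torch.randn(6, 128)
#     residual = torch.randn(6, 128)
#     h_pre, h_post = split_micro_batches_tensors(hidden_states, 4)
#     (h_pre, h_post), (r_pre, r_post) = split_micro_batches_tensors(
#         [hidden_states, residual], 4)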


@dataclass
class MultiStreamStepMetadata:
    """Stream and events used to fence a single communication step."""
    comm_stream: Optional[torch.npu.Stream] = None
    before_comm_event: Optional[torch.npu.Event] = None
    after_comm_event: Optional[torch.npu.Event] = None
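
# Illustrative sketch (assumed usage pattern, not confirmed by this module):
# fencing communication on a side stream between two compute segments.
#
#     step = MultiStreamStepMetadata(comm_stream=torch.npu.Stream(),
#                                    before_comm_event=torch.npu.Event(),
#                                    after_comm_event=torch.npu.Event())
#     step.before_comm_event.record()       # compute output is ready
#     with torch.npu.stream(step.comm_stream):
#         step.before_comm_event.wait()     # comm waits on compute
#         ...                               # launch the collective here
#         step.after_comm_event.record()
#     step.after_comm_event.wait()          # compute waits on comm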


@dataclass
class MultiStreamConfig:
    """Controls the behavior of multi-stream models."""
    min_total_tokens_to_split: int = 256
    min_prefill_tokens_to_split: int = 64
    num_micro_batches: int = 2
    imbalance_ratio: float = 0.1
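
# Illustrative sketch: a stricter config that only splits batches carrying at
# least 512 tokens overall and 128 prefill tokens (values are hypothetical).
#
#     ms_config = MultiStreamConfig(min_total_tokens_to_split=512,
#                                   min_prefill_tokens_to_split=128)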


class MultiStreamMetadata:
    # compute ("direct") stream
    calculate_stream = None
    # communication ("delay") stream
    communicate_stream = None
    # per-layer, per-micro-batch synchronization events
    ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
    # multi-stream flag
    enable_multi_stream: bool = False

    def __init__(
        self,
        calculate_stream: torch.npu.Stream,
        communicate_stream: torch.npu.Stream,
        start_layer: int,
        end_layer: int,
        event_keys: List[MSEventKey],
        multistream_config: Optional[MultiStreamConfig],
        causal_lm: bool = True,
    ):
        self.calculate_stream = calculate_stream
        self.communicate_stream = communicate_stream
        self.start_layer = start_layer
        self.end_layer = end_layer
        self.ms_config = multistream_config
        self.causal_lm = causal_lm
        # Shadow the class-level dict with a per-instance one so instances do
        # not share a mutable default.
        self.ms_events = {}
        self._build_events(event_keys)
        self._build_ms_split_config()

    def _build_events(self, event_keys):
        if self.ms_config is not None:
            for i in range(self.start_layer - 1, self.end_layer):
                self.ms_events[i] = {}
                for j in range(self.ms_config.num_micro_batches):
                    self.ms_events[i][j] = {}
                    for key in event_keys:
                        self.ms_events[i][j][key] = torch.npu.Event()

    def _build_ms_split_config(self):
        if self.ms_config is not None:
            cfg = self.ms_config
            self.ms_split_config = MSAttentionMetadataSplitConfig(
                num_micro_batches=cfg.num_micro_batches,
                min_total_tokens_to_split=cfg.min_total_tokens_to_split,
                min_prefill_tokens_to_split=cfg.min_prefill_tokens_to_split,
            )

    def try_wait_event(self, layer_index: int, micro_batch_index: int,
                       event_key: MSEventKey):
        self.ms_events[layer_index][micro_batch_index][event_key].wait()

    def try_record_event(self, layer_index: int, micro_batch_index: int,
                         event_key: MSEventKey):
        self.ms_events[layer_index][micro_batch_index][event_key].record()

    def split_micro_batch(
        self,
        attn_metadata: "AscendMLAMetadata",
        input_tensors: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        intermediate_tensors_keys: Optional[List[str]] = None,
    ) -> Tuple[bool,
               Union[AscendMLAMetadata, List[AscendMLAMetadata]],
               Union[List[torch.Tensor], List[List[torch.Tensor]]],
               Union[IntermediateTensors, List[IntermediateTensors]]]:
        attn_metadata_list = attn_metadata.split_metadata_for_multistream(
            self.ms_split_config)
        if len(attn_metadata_list) == 1:
            # Splitting was declined (e.g. too few tokens): run unsplit.
            return (False, attn_metadata_list[0], input_tensors,
                    intermediate_tensors)
        # The first micro-batch's slot mapping determines the split point.
        split_index = attn_metadata_list[0].slot_mapping.shape[0]
        input_tensors = split_micro_batches_tensors(input_tensors,
                                                    split_index)
        if intermediate_tensors is not None:
            inter_tensors_list = split_micro_batches_tensors(
                intermediate_tensors.tensors, split_index,
                intermediate_tensors_keys)
            intermediate_tensors = [
                IntermediateTensors(inter_tensors)
                for inter_tensors in inter_tensors_list
            ]
        return True, attn_metadata_list, input_tensors, intermediate_tensors

    def merge_micro_batches(
        self, input_tensors: Union[List[torch.Tensor],
                                   List[List[torch.Tensor]]]
    ) -> List[torch.Tensor]:
        # Already unsplit (or nothing to merge): return as-is.
        if input_tensors is None or isinstance(input_tensors[0],
                                               torch.Tensor):
            return input_tensors
        batch: List[Optional[torch.Tensor]] = []
        for tensors in input_tensors:
            if tensors is None or tensors[0] is None:
                batch.append(None)
            else:
                # Concatenate the micro-batch shards back along dim 0.
                batch.append(torch.cat(tensors, dim=0))
        return batch
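
# Illustrative sketch (names are hypothetical): merge_micro_batches is the
# inverse of the split; each [pre, post] pair is concatenated along dim 0.
#
#     pairs = split_micro_batches_tensors([hidden_states, residual], 4)
#     merged = ms_metadata.merge_micro_batches(pairs)
#     # merged[0] == torch.cat([hidden_states[:4], hidden_states[4:]], dim=0)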


def make_multistream_metadata_ds(
    start_layer: int,
    end_layer: int,
    causal_lm: bool = True,
    multistream_config: Optional[MultiStreamConfig] = None,
):
    if multistream_config is None:
        return None
    event_keylist = [
        MSEventKey.ATTN_COM_FINISH,
        MSEventKey.ATTN_AR_FINISH,
        MSEventKey.FFN_COM_FINISH,
        MSEventKey.FFN_AR_FINISH,
        MSEventKey.MOE_BEFORE_COMM,
        MSEventKey.MOE_AFTER_COMM,
        MSEventKey.MOE_SE_COMM_FINISH,
        MSEventKey.MOE_SE_COMP_FINISH,
        MSEventKey.MOE_GATE_FINISH,
    ]
    return MultiStreamMetadata(
        calculate_stream=torch.npu.current_stream(),
        communicate_stream=torch.npu.Stream(),
        start_layer=start_layer,
        end_layer=end_layer,
        multistream_config=multistream_config,
        event_keys=event_keylist,
        causal_lm=causal_lm,
    )
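
# Illustrative sketch: building metadata for a model whose layers span
# [start_layer, end_layer); `num_hidden_layers` is assumed model context.
#
#     ms_metadata = make_multistream_metadata_ds(
#         start_layer=0,
#         end_layer=num_hidden_layers,
#         multistream_config=MultiStreamConfig(),
#     )
#     # Returns None when multistream_config is None, i.e. the feature is off.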