Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
Major overhaul
0
vllm_npu/multistream/__init__.py
Normal file
29
vllm_npu/multistream/base.py
Normal file
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from enum import Enum


class MSEventKey(Enum):
    ATTN_COM_FINISH = 0
    ATTN_AR_FINISH = 1
    FFN_COM_FINISH = 2
    FFN_AR_FINISH = 3
    # events for MOE dispatch and combine
    MOE_BEFORE_COMM = 4
    MOE_AFTER_COMM = 5
    # events for shared expert
    MOE_SE_COMM_FINISH = 6
    MOE_SE_COMP_FINISH = 7
    MOE_GATE_FINISH = 8


@dataclass
class MSAttentionMetadataSplitConfig:
    """
    Micro-batch split config for splitting attention metadata.
    """
    # number of micro batches
    num_micro_batches: int = 2
    # split micro batches only when total tokens >= min_total_tokens_to_split
    min_total_tokens_to_split: int = 256
    # split micro batches only when prefill tokens >= min_prefill_tokens_to_split
    min_prefill_tokens_to_split: int = 64
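
A minimal sketch of how these definitions are consumed (not part of the commit; it only assumes the plugin is importable):

from vllm_npu.multistream.base import MSAttentionMetadataSplitConfig, MSEventKey

cfg = MSAttentionMetadataSplitConfig()  # 2 micro-batches by default
# Per the field comments above, splitting is only attempted once a batch
# clears these token thresholds.
print(cfg.min_total_tokens_to_split, cfg.min_prefill_tokens_to_split)  # 256 64
print(MSEventKey.FFN_AR_FINISH.value)  # 3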
67
vllm_npu/multistream/context.py
Normal file
@@ -0,0 +1,67 @@
from contextlib import contextmanager
from typing import Any

_ms_comm_context: Any = None
_cur_micro_batch_num: int = -1
_ms_layer_index_context: int = -1
_ms_metadata_context: Any = None
_ms_attn_metadata_context: Any = None


def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
                                  attn_metadata: Any):
    """
    set multistream layer context before transformer layers
    """
    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
    _ms_layer_index_context = start_layer
    _ms_metadata_context = ms_metadata
    _ms_attn_metadata_context = attn_metadata


def reset_multistream_layer_context():
    """
    reset multistream layer context
    """
    global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
    _ms_layer_index_context = -1
    _ms_metadata_context = None
    _ms_attn_metadata_context = None


def get_multistream_layer_context():
    """
    get multistream layer context
    """
    return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context


def advance_step_multistream_layer_context():
    """
    advance multistream layer index context
    """
    global _ms_layer_index_context
    _ms_layer_index_context += 1


def get_multistream_comm_context() -> Any:
    """Get the current comm forward context."""
    return _ms_comm_context


def get_multistream_microbatch_context() -> int:
    return _cur_micro_batch_num


@contextmanager
def set_multistream_context(context: Any, micro_batch_num: int):
    """A context manager that stores the current comm forward context,
    can be attention metadata, etc."""
    global _ms_comm_context, _cur_micro_batch_num
    _ms_comm_context = context
    _cur_micro_batch_num = micro_batch_num
    try:
        yield
    finally:
        _ms_comm_context = None
        _cur_micro_batch_num = -1
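
A minimal usage sketch for the context manager above (not part of the commit):

from vllm_npu.multistream.context import (get_multistream_comm_context,
                                          get_multistream_microbatch_context,
                                          set_multistream_context)

comm_ctx = object()  # placeholder for e.g. per-micro-batch attention metadata
with set_multistream_context(comm_ctx, micro_batch_num=0):
    assert get_multistream_comm_context() is comm_ctx
    assert get_multistream_microbatch_context() == 0
# both globals are restored to their defaults on exit
assert get_multistream_comm_context() is None
assert get_multistream_microbatch_context() == -1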
22
vllm_npu/multistream/decorator.py
Normal file
@@ -0,0 +1,22 @@
from .context import (get_multistream_layer_context,
                      get_multistream_microbatch_context)


# vLLM v1 uses get_forward_context to get the attn_metadata;
# we can use this decorator to update the attn metadata.
def set_multistream_support():

    def decorator(func):

        def wrapper():
            context = func()
            layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
            )
            micro_batch_num = get_multistream_microbatch_context()
            if layer_index != -1 and micro_batch_num != -1:
                context.attn_metadata = attn_metadata[micro_batch_num]
            return context

        return wrapper

    return decorator
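
A hypothetical sketch of how the decorator is intended to be applied (not part of the commit); it only shows the wrapping, since calling the patched getter requires an active vLLM forward context:

from vllm.forward_context import get_forward_context
from vllm_npu.multistream.decorator import set_multistream_support

# Wrap vLLM's getter; when a multistream layer context and micro-batch index
# are active, the wrapper swaps in that micro-batch's attn_metadata.
patched_get_forward_context = set_multistream_support()(get_forward_context)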
61
vllm_npu/multistream/layers.py
Normal file
@@ -0,0 +1,61 @@
from typing import List, Optional, Tuple, Union

import torch
from vllm.forward_context import get_forward_context

from .base import MSEventKey
from .context import (get_multistream_layer_context,
                      reset_multistream_layer_context,
                      set_multistream_layer_context)
from .metadata import MultiStreamMetadata


class MultiStreamPreTransformerLayer(torch.nn.Module):

    def __init__(self, multistream_metadata: MultiStreamMetadata):
        super().__init__()
        self.multistream_metadata = multistream_metadata

    def forward(
        self,
        input_tensors: List[torch.Tensor],
    ):
        attn_metadata = get_forward_context().attn_metadata
        if self.multistream_metadata is None or attn_metadata is None:
            set_multistream_layer_context(-1, None, None)
            return attn_metadata, input_tensors
        # TODO add attn_metadata management
        do_ms, attn_metadata, input_tensors, _ = self.multistream_metadata.split_micro_batch(
            attn_metadata, input_tensors)
        if do_ms:
            set_multistream_layer_context(
                self.multistream_metadata.start_layer,
                self.multistream_metadata, attn_metadata)
        else:
            set_multistream_layer_context(-1, None, None)
        return attn_metadata, input_tensors


class MultiStreamPostTransformerLayer(torch.nn.Module):

    def __init__(self, multistream_metadata: MultiStreamMetadata):
        super().__init__()
        self.multistream_metadata = multistream_metadata

    def forward(self,
                input_tensors: Union[List[Tuple[torch.Tensor]],
                                     List[torch.Tensor],
                                     List[List[torch.Tensor]]],
                wait_layer_index: Optional[int] = None):
        if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
            return input_tensors
        layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
        )
        if layer_index >= 0:
            true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
            self.multistream_metadata.try_wait_event(
                true_wait_layer,
                self.multistream_metadata.ms_config.num_micro_batches - 1,
                MSEventKey.FFN_AR_FINISH)
        reset_multistream_layer_context()
        return self.multistream_metadata.merge_micro_batches(input_tensors)
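
A hypothetical wiring sketch (not part of the commit): `ms_metadata`, `hidden_states` and `residual` are assumed to come from the surrounding model code, so this is illustrative rather than runnable on its own:

from vllm_npu.multistream.layers import (MultiStreamPreTransformerLayer,
                                         MultiStreamPostTransformerLayer)

pre = MultiStreamPreTransformerLayer(ms_metadata)
post = MultiStreamPostTransformerLayer(ms_metadata)

# Before the decoder stack: possibly split the inputs into micro-batches and
# set the multistream layer context.
attn_metadata, input_tensors = pre([hidden_states, residual])
# ... transformer layers run here, picking up the per-micro-batch
# attn_metadata while the multistream context is active ...
# After the stack: wait for the last FFN all-reduce event, reset the context,
# and merge the micro-batches back together.
hidden_states, residual = post(input_tensors)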
182
vllm_npu/multistream/metadata.py
Normal file
@@ -0,0 +1,182 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
from vllm.sequence import IntermediateTensors

from vllm_npu.attention.mla_v1 import AscendMLAMetadata

from .base import MSAttentionMetadataSplitConfig, MSEventKey


def split_micro_batches_tensors(input_tensors,
                                split_index: int,
                                keys: Optional[List[str]] = None):
    if isinstance(input_tensors, list):
        micro_batches = []
        for tensor in input_tensors:
            if tensor is None:
                micro_batches.append([None, None])
            else:
                micro_batches.append(
                    [tensor[:split_index], tensor[split_index:]])
        return micro_batches
    elif isinstance(input_tensors, torch.Tensor):
        return [input_tensors[:split_index], input_tensors[split_index:]]
    elif input_tensors is None:
        return [None, None]
    elif isinstance(input_tensors, Dict):
        assert keys is not None
        micro_batches_pre = {}
        for key in keys:
            micro_batches_pre[key] = input_tensors[key][:split_index]
        micro_batches_post = {}
        for key in keys:
            micro_batches_post[key] = input_tensors[key][split_index:]
        return [micro_batches_pre, micro_batches_post]
    else:
        raise NotImplementedError


@dataclass
class MultiStreamStepMetadata:
    comm_stream: torch.npu.Stream = None
    before_comm_event: torch.npu.Event = None
    after_comm_event: torch.npu.Event = None


@dataclass
class MultiStreamConfig:
    """Controls the behavior of multi-stream models."""
    min_total_tokens_to_split: int = 256
    min_prefill_tokens_to_split: int = 64
    num_micro_batches: int = 2
    imbalance_ratio: float = 0.1


class MultiStreamMetadata:
    # direct stream
    calculate_stream = None
    # delay stream
    communicate_stream = None
    # events
    ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
    # multi-stream-flag
    enable_multi_stream: bool = False

    def __init__(
        self,
        calculate_stream: torch.npu.Stream,
        communicate_stream: torch.npu.Stream,
        start_layer: int,
        end_layer: int,
        event_keys: List[MSEventKey],
        multistream_config: Optional[MultiStreamConfig],
        causal_lm: bool = True,
    ):
        self.calculate_stream = calculate_stream
        self.communicate_stream = communicate_stream
        self.start_layer = start_layer
        self.end_layer = end_layer
        self.ms_config = multistream_config
        self.causal_lm = causal_lm
        self._build_events(event_keys)
        self._build_ms_split_config()

    def _build_events(self, event_keys):
        if self.ms_config is not None:
            for i in range(self.start_layer - 1, self.end_layer):
                self.ms_events[i] = {}
                for j in range(self.ms_config.num_micro_batches):
                    self.ms_events[i][j] = {}
                    for key in event_keys:
                        self.ms_events[i][j][key] = torch.npu.Event()

    def _build_ms_split_config(self):
        if self.ms_config is not None:
            self.ms_split_config = MSAttentionMetadataSplitConfig(
                num_micro_batches=self.ms_config.num_micro_batches,
                min_total_tokens_to_split=self.ms_config.
                min_total_tokens_to_split,
                min_prefill_tokens_to_split=self.ms_config.
                min_prefill_tokens_to_split,
            )

    def try_wait_event(self, layer_index: int, micro_batch_index: int,
                       event_key: MSEventKey):
        self.ms_events[layer_index][micro_batch_index][event_key].wait()

    def try_record_event(self, layer_index: int, micro_batch_index: int,
                         event_key: MSEventKey):
        self.ms_events[layer_index][micro_batch_index][event_key].record()

    def split_micro_batch(
        self,
        attn_metadata: "AscendMLAMetadata",
        input_tensors: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        intermediate_tensors_keys: Optional[List[str]] = None,
    ) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
            List[torch.Tensor], List[List[torch.Tensor]]], Union[
                IntermediateTensors, List[IntermediateTensors]]]:
        attn_metadata_list = attn_metadata.split_metadata_for_multistream(
            self.ms_split_config)
        if len(attn_metadata_list) == 1:
            return False, attn_metadata_list[
                0], input_tensors, intermediate_tensors
        split_index = attn_metadata_list[0].slot_mapping.shape[0]
        input_tensors = split_micro_batches_tensors(input_tensors,
                                                    split_index)
        if intermediate_tensors is not None:
            inter_tensors_list = split_micro_batches_tensors(
                intermediate_tensors.tensors, split_index,
                intermediate_tensors_keys)
            intermediate_tensors = [
                IntermediateTensors(inter_tensors)
                for inter_tensors in inter_tensors_list
            ]
        return True, attn_metadata_list, input_tensors, intermediate_tensors

    def merge_micro_batches(
        self, input_tensors: Union[List[torch.Tensor],
                                   List[List[torch.Tensor]]]
    ) -> List[torch.Tensor]:
        if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
            return input_tensors
        batch: List[Optional[torch.Tensor]] = []
        for tensors in input_tensors:
            if tensors is None or tensors[0] is None:
                batch.append(None)
            else:
                batch.append(torch.cat(tensors, dim=0))
        return batch


def make_multistream_metadata_ds(
    start_layer: int,
    end_layer: int,
    causal_lm: bool = True,
    multistream_config: Optional[MultiStreamConfig] = None,
):
    if multistream_config is None:
        return None
    event_keylist = [
        MSEventKey.ATTN_COM_FINISH,
        MSEventKey.ATTN_AR_FINISH,
        MSEventKey.FFN_COM_FINISH,
        MSEventKey.FFN_AR_FINISH,
        MSEventKey.MOE_BEFORE_COMM,
        MSEventKey.MOE_AFTER_COMM,
        MSEventKey.MOE_SE_COMM_FINISH,
        MSEventKey.MOE_SE_COMP_FINISH,
        MSEventKey.MOE_GATE_FINISH,
    ]
    return MultiStreamMetadata(
        calculate_stream=torch.npu.current_stream(),
        communicate_stream=torch.npu.Stream(),
        start_layer=start_layer,
        end_layer=end_layer,
        multistream_config=multistream_config,
        event_keys=event_keylist,
        causal_lm=causal_lm,
    )
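
A small sketch of split_micro_batches_tensors on CPU tensors (not part of the commit; importing this module assumes vllm and the plugin's attention modules are available):

import torch
from vllm_npu.multistream.metadata import split_micro_batches_tensors

h = torch.arange(6).unsqueeze(-1)      # stand-in for hidden_states, 6 tokens
r = torch.arange(6, 12).unsqueeze(-1)  # stand-in for residual
# Split every tensor in the list at token index 3.
[(h_pre, h_post), (r_pre, r_post)] = split_micro_batches_tensors([h, r], 3)
print(h_pre.shape, h_post.shape)  # torch.Size([3, 1]) torch.Size([3, 1])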
247
vllm_npu/multistream/ms_split.py
Normal file
@@ -0,0 +1,247 @@
from copy import deepcopy
from typing import Any, List, Optional

import numpy as np
import torch

from vllm_npu.attention.attention_v1 import AscendAttentionState

from .base import MSAttentionMetadataSplitConfig


def compute_split_seq_index(
    query_lens: Optional[list[int]],
    attn_state: AscendAttentionState,
    num_tokens: int,
    imbalance_ratio: float = 0.1,
) -> list[int]:
    if attn_state != AscendAttentionState.DecodeOnly:
        assert query_lens is not None
        total_tokens = sum(query_lens)
        # the first index in the last split
        tokens, split_index = 0, 0
        for value in query_lens:
            tokens += value
            split_index += 1
            if tokens >= total_tokens // 2:
                # check the current split index
                if abs(tokens -
                       total_tokens // 2) < total_tokens * imbalance_ratio:
                    return [tokens, split_index]
                # check the previous split index
                elif abs(tokens - total_tokens // 2 -
                         value) < total_tokens * imbalance_ratio:
                    return [tokens - value, split_index - 1]
                # fail to split if it is imbalanced
                # TODO: split tokens within a sequence
                else:
                    return [0, 0]
    else:
        tokens = num_tokens // 2
        return [tokens, tokens]
    return [0, 0]


def split_attn_tensor_type(
    input_tensor: torch.Tensor,
    index: int,
) -> List[torch.Tensor]:
    return [input_tensor[:index], input_tensor[index:]]


def split_attn_int_type(
    var: int,
    index: int,
) -> List[int]:
    return [min(var, index), max(var - index, 0)]


def model_input_split_v1_mla_attn(
    attn_metadata,
    _metadata_cls,
    ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[Any]:
    assert 0 < ms_split_config.num_micro_batches < 3
    if attn_metadata is None:
        return [attn_metadata]
    [token_index,
     seq_index] = compute_split_seq_index(attn_metadata.query_lens,
                                          attn_metadata.attn_state,
                                          attn_metadata.num_decode_tokens)
    if token_index == 0 or seq_index == 0 or seq_index == len(
            attn_metadata.query_lens):
        return [attn_metadata]

    query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
                                   dtype=int)
    np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
    if attn_metadata.num_prefills > 0:
        prefill_query_start_loc = np.zeros(
            shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
        np.cumsum(attn_metadata.prefill.query_lens,
                  out=prefill_query_start_loc[1:])

    # split attn metadata
    [slot_mapping_pre,
     slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
                                                 token_index)
    [num_decodes_pre,
     num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
                                             seq_index)
    [num_decode_tokens_pre, num_decode_tokens_post
     ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
    [num_prefills_pre, num_prefills_post
     ] = split_attn_int_type(attn_metadata.num_prefills,
                             max(0, seq_index - attn_metadata.num_decodes))
    seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
    [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)

    query_start_loc_pre = query_start_loc_post = None
    if attn_metadata.query_start_loc is not None:
        query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
        query_start_loc_post = deepcopy(
            attn_metadata.query_start_loc[seq_index:]
        ) - attn_metadata.query_start_loc[seq_index]
    [block_table_pre,
     block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
                                                seq_index)
    assert attn_metadata.attn_mask is not None
    if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
        # the attn_mla kernel in torch npu only accepts a 128*128 attn mask
        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
        attn_state_pre = attn_state_post = attn_metadata.attn_state
    elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
        # should be None in the decode-only state
        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
    else:
        # chunked prefill
        if num_prefills_pre > 0:
            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
                seq_lens_pre)].contiguous()
            attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_post = attn_metadata.attn_mask[
                token_index:, :max(seq_lens_post)].contiguous()
        else:
            attn_state_pre = AscendAttentionState.DecodeOnly
            attn_mask_pre = None
            attn_state_post = AscendAttentionState.ChunkedPrefill
            attn_mask_post = attn_metadata.attn_mask[
                token_index:, :max(seq_lens_post)].contiguous()
    from vllm_npu.attention.mla_v1 import (AscendMLADecodeMetadata,
                                           AscendMLAPrefillMetadata)
    if num_prefills_pre > 0:
        # split metadata.prefill
        [input_positions_pre, input_positions_post] = split_attn_tensor_type(
            attn_metadata.prefill.input_positions,
            token_index - attn_metadata.num_decode_tokens)
        [block_tables_pre, block_tables_post
         ] = split_attn_tensor_type(attn_metadata.prefill.block_table,
                                    seq_index - attn_metadata.num_decodes)
        [prefill_query_lens_pre, prefill_query_lens_post
         ] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
                                    seq_index - attn_metadata.num_decodes)
        prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[
            :seq_index + 1 - attn_metadata.num_decodes]
        prefill_query_start_loc_post = deepcopy(
            attn_metadata.prefill.query_start_loc[seq_index -
                                                  attn_metadata.num_decodes:]
        ) - attn_metadata.prefill.query_start_loc[seq_index -
                                                  attn_metadata.num_decodes]
        context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
        context_len_post = seq_lens_post
        prefill_max_query_len_pre = max(prefill_query_lens_pre)
        prefill_max_query_len_post = max(prefill_query_lens_post)
        prefill_pre = AscendMLAPrefillMetadata(
            attn_mask=attn_mask_pre,
            query_lens=prefill_query_lens_pre,
            seq_lens=seq_lens_pre,
            query_start_loc=prefill_query_start_loc_pre,
            input_positions=input_positions_pre,
            context_lens=context_len_pre,
            block_table=block_tables_pre,
            max_query_len=prefill_max_query_len_pre,
            max_seq_lens=context_len_pre.max().item(),
        )
        prefill_post = AscendMLAPrefillMetadata(
            attn_mask=attn_mask_post,
            query_lens=prefill_query_lens_post,
            seq_lens=seq_lens_post,
            query_start_loc=prefill_query_start_loc_post,
            input_positions=input_positions_post,
            context_lens=context_len_post,
            block_table=block_tables_post,
            max_query_len=prefill_max_query_len_post,
            max_seq_lens=context_len_post.max().item(),
        )
        decode_pre = attn_metadata.decode
        decode_post = None
    else:
        # prefill is None, split metadata.decode
        [input_positions_pre, input_positions_post
         ] = split_attn_tensor_type(attn_metadata.decode.input_positions,
                                    token_index)
        [block_tables_pre, block_tables_post
         ] = split_attn_tensor_type(attn_metadata.decode.block_table,
                                    seq_index)
        [decode_seq_lens_pre,
         decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
        decode_pre = AscendMLADecodeMetadata(
            input_positions=input_positions_pre,
            block_table=block_tables_pre,
            seq_lens=decode_seq_lens_pre,
            max_seq_lens=max(decode_seq_lens_pre),
            seq_lens_list=decode_seq_lens_pre.tolist(),
        )
        decode_post = AscendMLADecodeMetadata(
            input_positions=input_positions_post,
            block_table=block_tables_post,
            seq_lens=decode_seq_lens_post,
            max_seq_lens=max(decode_seq_lens_post),
            seq_lens_list=decode_seq_lens_post.tolist(),
        )
        prefill_pre = None
        prefill_post = attn_metadata.prefill
    # construct metadata
    attention_metadata_pre = _metadata_cls(
        num_actual_tokens=token_index,
        num_input_tokens=token_index,
        head_dim=attn_metadata.head_dim,
        slot_mapping=slot_mapping_pre,
        seq_lens=seq_lens_pre,
        query_start_loc=query_start_loc_pre,
        block_tables=block_table_pre,
        num_decodes=num_decodes_pre,
        num_prefills=num_prefills_pre,
        num_decode_tokens=num_decode_tokens_pre,
        attn_state=attn_state_pre,
        attn_mask=attn_mask_pre,
        prefill=prefill_pre,
        decode=decode_pre,
        enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
    )
    attention_metadata_post = _metadata_cls(
        num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
        num_input_tokens=attn_metadata.num_input_tokens - token_index,
        head_dim=attn_metadata.head_dim,
        slot_mapping=slot_mapping_post,
        seq_lens=seq_lens_post,
        query_start_loc=query_start_loc_post,
        block_tables=block_table_post,
        num_decodes=num_decodes_post,
        num_prefills=num_prefills_post,
        num_decode_tokens=num_decode_tokens_post,
        attn_mask=attn_mask_post,
        attn_state=attn_state_post,
        prefill=prefill_post,
        decode=decode_post,
        enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
    )
    return [attention_metadata_pre, attention_metadata_post]
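
A worked example for compute_split_seq_index (not part of the commit): with query_lens [3, 5, 4, 4] and 16 total tokens, the loop stops once the running count reaches half the tokens and that point is within the imbalance tolerance:

from vllm_npu.attention.attention_v1 import AscendAttentionState
from vllm_npu.multistream.ms_split import compute_split_seq_index

print(compute_split_seq_index(query_lens=[3, 5, 4, 4],
                              attn_state=AscendAttentionState.PrefillNoCache,
                              num_tokens=16))
# -> [8, 2]: micro-batch 0 takes the first 2 sequences (8 tokens),
#    micro-batch 1 takes the remaining 8 tokens.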