Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git, synced 2026-02-20 19:50:15 +00:00
Major overhaul (大改)
0
vllm_npu/torchair/__init__.py
Normal file
0
vllm_npu/torchair/models/__init__.py
Normal file
363
vllm_npu/torchair/models/qwen2.py
Normal file
@@ -0,0 +1,363 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from collections.abc import Iterable
from typing import Any, List, Optional, Union

import torch
import torch.nn.functional as F
import vllm
import vllm.envs as envs
from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather,
                              tensor_model_parallel_reduce_scatter)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Attention  # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM  # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.sequence import IntermediateTensors

from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.attention.attention_v1 import AscendAttentionState


def all_gather_and_maybe_unpad(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
    if pad_size > 0:
        return hidden_states[:-pad_size, :]
    return hidden_states


def maybe_pad_and_reduce_scatter(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    if pad_size > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
    return hidden_states


class CustomQwen2Attention(Qwen2Attention):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        rope_scaling: Optional[tuple] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position=max_position,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=prefix,
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config)
        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            q, k = self.rotary_emb(positions,
                                   q,
                                   k,
                                   is_prefill=False,
                                   is_qwen_torchair=True)
            forward_kwargs = {}
            if envs.VLLM_USE_V1:
                output_shape = q.shape
                output = torch.empty(output_shape,
                                     dtype=q.dtype,
                                     device=q.device)
                forward_kwargs['output'] = output

            attn_output = self.attn.impl.forward(self.attn,
                                                 q,
                                                 k,
                                                 v,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 trace_flag=False,
                                                 **forward_kwargs)
            output, _ = self.o_proj(attn_output)
            return output
        else:
            if type(self.rotary_emb) is RotaryEmbedding:
                q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
            else:
                q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output


class CustomQwen2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        dual_chunk_attention_config = getattr(config,
                                              "dual_chunk_attention_config",
                                              None)

        # By default, Qwen2 uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = CustomQwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen2Model(Qwen2Model):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=decoder_layer_type)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            kv_cache = kv_caches[i - self.start_layer] \
                if kv_caches is not None else None
            hidden_states, residual = layer(positions,
                                            hidden_states,
                                            residual,
                                            kv_cache=kv_cache,
                                            attn_metadata=attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen2Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata=None,  # type: ignore
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)


vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM
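The last line of this file overrides vLLM's upstream Qwen2ForCausalLM by rebinding the attribute in the vllm module namespace when the plugin module is imported. A minimal sketch of how that override could be checked, assuming both vllm and the vllm_npu plugin are importable (the check itself is illustrative and not part of the commit):

# Illustrative check only; assumes vllm and vllm_npu are installed.
import vllm
import vllm_npu.torchair.models.qwen2 as npu_qwen2  # importing this module applies the patch

# After the import, the module attribute that the diff reassigns points at the
# NPU-specific subclass defined above.
assert vllm.model_executor.models.qwen2.Qwen2ForCausalLM is npu_qwen2.CustomQwen2ForCausalLM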
537
vllm_npu/torchair/models/qwen3_moe.py
Normal file
@@ -0,0 +1,537 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Any, List, Optional, Union

import torch
import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                             get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
                                                   SupportsLoRA, SupportsPP)
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                  Qwen3MoeDecoderLayer,
                                                  Qwen3MoeForCausalLM,
                                                  Qwen3MoeMLP, Qwen3MoeModel,
                                                  Qwen3MoeSparseMoeBlock)
from vllm.model_executor.models.utils import (
    PPMissingLayer, extract_layer_index,
    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors

from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.attention.attention_v1 import AscendAttentionState
from vllm_npu.torchair.ops.sequence_parallel import (MetadataForPadding,
                                                     init_metadata_for_sp)
from vllm_npu.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE


class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        nn.Module.__init__(self)
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = TorchairAscendFusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

        self.top_k = config.num_experts_per_tok

        self.dp_size = get_dp_group().world_size

        self.tp_group = get_tp_group().device_group
        self.tp_rank = get_tp_group().rank_in_group
        self.ep_group = get_ep_group()

        self.params_dtype = torch.get_default_dtype()

    def forward(
        self,
        hidden_states,
        attn_metadata=None,
        _metadata_for_padding: Optional[MetadataForPadding] = None,
    ):
        if attn_metadata is None:
            attn_metadata = get_forward_context().attn_metadata
        # When the profile run executes, force experts to load balanced tokens
        # to avoid high memory consumption on a single rank.
        enable_force_load_balance = get_forward_context().in_profile_run
        is_prefill = get_forward_context().with_prefill

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=self.top_k,
            enable_force_load_balance=enable_force_load_balance,
            shared_experts=None,
            _metadata_for_padding=_metadata_for_padding,
        )

        return hidden_states


class CustomQwen3MoeAttention(Qwen3MoeAttention):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_kv_heads,
                                          bias=qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")

        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

    @staticmethod
    def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
                      head_dim: int, q_norm, k_norm):
        q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
        q_by_head = q_norm(q_by_head)
        q = q_by_head.view(q.shape)

        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
        k_by_head = k_norm(k_by_head)
        k = k_by_head.view(k.shape)

        return q, k, v

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size,
                                     self.head_dim, self.q_norm, self.k_norm)

        if (self.torchair_graph_enabled and attn_metadata is not None and
                attn_metadata.attn_state == AscendAttentionState.DecodeOnly):
            q, k = self.rotary_emb(positions,
                                   q,
                                   k,
                                   is_prefill=False,
                                   is_qwen_torchair=True)
            forward_kwargs = {}
            if envs.VLLM_USE_V1:
                output_shape = q.shape
                output = torch.empty(output_shape,
                                     dtype=q.dtype,
                                     device=q.device)
                forward_kwargs['output'] = output

            attn_output = self.attn.impl.forward(self.attn,
                                                 q,
                                                 k,
                                                 v,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 trace_flag=False,
                                                 **forward_kwargs)
            output, _ = self.o_proj(attn_output)
            return output
        else:
            q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output


class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        vllm_config: Optional[VllmConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = CustomQwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', False),
            head_dim=getattr(config, 'head_dim', None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # `mlp_only_layers` in the config.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                           config.mlp_only_layers)
        self.use_aclgraph = (vllm_config is not None
                             and vllm_config.compilation_config.level
                             == CompilationLevel.PIECEWISE
                             and not vllm_config.model_config.enforce_eager)
        if (layer_idx not in mlp_only_layers) and (
                config.num_experts > 0 and
                (layer_idx + 1) % config.decoder_sparse_step == 0):
            if not self.use_aclgraph:
                # FIXME: custom sparse moe block doesn't work with aclgraph.
                self.mlp = CustomSparseMoeBlock(config=config,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.mlp")
            else:
                self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
                                                  prefix=f"{prefix}.mlp")
        else:
            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                   intermediate_size=config.intermediate_size,
                                   hidden_act=config.hidden_act,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.enable_sequence_parallelism = (
            vllm_config.compilation_config.pass_config.
            enable_sequence_parallelism if vllm_config is not None else False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        _metadata_for_padding: Optional[MetadataForPadding] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:

        # Avoid precision issues in the decode phase when sequence parallelism
        # is enabled only for prefill.
        if not self.enable_sequence_parallelism:
            self.self_attn.o_proj.reduce_results = True
        else:
            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True

        # Self Attention
        if residual is None:
            residual = hidden_states
            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
                residual = _metadata_for_padding.padding_slice(residual)

            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
                hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
                hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        if not self.use_aclgraph:
            hidden_states = self.mlp(
                hidden_states, _metadata_for_padding=_metadata_for_padding)
        else:
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=f"{prefix}.embed_tokens")
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CustomQwen3MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                vllm_config=vllm_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        _metadata_for_padding: Optional[MetadataForPadding] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                kv_caches[i -
                          self.start_layer] if kv_caches is not None else None,
                attn_metadata,
                _metadata_for_padding=_metadata_for_padding)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
                hidden_states)

        return hidden_states


class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        SupportsPP.__init__(self)
        SupportsLoRA.__init__(self)
        MixtureOfExperts.__init__(self)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
                                         prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
        # Set MoE hyperparameters
        self.expert_weights: list[torch.Tensor] = []

        self.moe_layers: list[FusedMoE] = []
        example_layer = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in model.layers.")

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        _metadata_for_padding = init_metadata_for_sp(
            input_ids, self.enable_sequence_parallelism)
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds, _metadata_for_padding)
        return hidden_states
218
vllm_npu/torchair/models/torchair_deepseek_mtp.py
Normal file
@@ -0,0 +1,218 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/deepseek_mtp.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_mtp import (
    DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
    SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors

from vllm_npu.torchair.models.torchair_deepseek_v2 import \
    TorchairDeepseekV2DecoderLayer


class TorchairDeepSeekShareHead(SharedHead):

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        nn.Module.__init__(self)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(config.vocab_size,
                                   config.hidden_size,
                                   quant_config=quant_config,
                                   prefix=maybe_prefix(prefix, "head"))


class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        nn.Module.__init__(self)

        self.tp_size = get_tensor_model_parallel_world_size()
        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
        self.shared_head = TorchairDeepSeekShareHead(config=config,
                                                     quant_config=quant_config,
                                                     prefix=maybe_prefix(
                                                         prefix,
                                                         "shared_head"))
        self.mtp_block = TorchairDeepseekV2DecoderLayer(
            config, prefix, model_config, cache_config, quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        assert inputs_embeds is not None
        # Mask inputs at position 0, as they are not needed by MTP.
        inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
                                    torch.zeros_like(inputs_embeds),
                                    inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        replace_allreduce = hidden_states.shape[0] % self.tp_size == 0

        hidden_states, residual = self.mtp_block(
            positions=positions,
            hidden_states=hidden_states,
            residual=None,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            replace_allreduce=replace_allreduce)
        hidden_states = residual + hidden_states
        return hidden_states


class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # Map to the exact layer index used in the weights.
        self.layers = torch.nn.ModuleDict({
            str(idx):
            TorchairDeepSeekMultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        # Note: torch._dynamo.exc.Unsupported: builtin: str
        self.layers_list = [
            self.layers[str(idx)]
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        ]
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        step_kv_cache = kv_caches[
            current_step_idx] if kv_caches is not None else None
        return self.layers_list[current_step_idx](
            input_ids,
            positions,
            step_kv_cache,
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        mtp_layer = self.layers_list[current_step_idx]
        logits = self.logits_processor(mtp_layer.shared_head.head,
                                       mtp_layer.shared_head(hidden_states))
        return logits


class TorchairDeepSeekMTP(DeepSeekMTP):
    # NOTE 1: On the NPU, the MTP layer of a quantized DeepSeek model is kept
    # unquantized.
    # NOTE 2: The description file generated by the current msmodelslim tool
    # does not include MTP layer info; add it manually and set the value to FLOAT.
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.model_config.hf_config
        self.model = TorchairDeepSeekMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, hidden_states, inputs_embeds,
                                   spec_step_idx)
        return hidden_states
1301
vllm_npu/torchair/models/torchair_deepseek_v2.py
Normal file
File diff suppressed because it is too large
28
vllm_npu/torchair/models/torchair_deepseek_v3.py
Normal file
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from vllm_npu.torchair.models.torchair_deepseek_v2 import \
    TorchairDeepseekV2ForCausalLM


class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
    pass
1118
vllm_npu/torchair/models/torchair_pangu_moe.py
Normal file
File diff suppressed because it is too large
0
vllm_npu/torchair/ops/__init__.py
Normal file
120
vllm_npu/torchair/ops/sequence_parallel.py
Normal file
@@ -0,0 +1,120 @@
import torch
from torch.nn import functional as F
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_tp_group, tensor_model_parallel_all_gather,
                              tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context

from vllm_npu.platform import NPUPlatform


class MetadataForPadding:

    def __init__(self,
                 padding_flag=False,
                 lengths_sum_padding=0,
                 lengths_sum_unpadding=0,
                 pad_size=0,
                 not_dummy_and_is_prefill=False):
        self.padding_flag = padding_flag
        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill

        self.lengths_sum_padding = lengths_sum_padding
        self.lengths_sum_unpadding = lengths_sum_unpadding
        self.pad_size = pad_size

        self.tp_size = get_tp_group().world_size
        self.tp_rank_in_group = get_tp_group().rank_in_group

        assert self.lengths_sum_padding % self.tp_size == 0
        self.slice_size = self.lengths_sum_padding // self.tp_size

        self.mc2_mask = torch.zeros(
            self.lengths_sum_padding,
            dtype=torch.bool,
            device=NPUPlatform.device_type,
        )
        self.mc2_mask[:lengths_sum_unpadding] = True

    def padding_aligned_reduce_scatter(self,
                                       data: torch.Tensor) -> torch.Tensor:
        if self.padding_flag:
            pad_size = self.pad_size
            padded_data = F.pad(data, (0, 0, 0, pad_size))
        else:
            padded_data = data
        padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
            padded_data, 0)

        return padded_data_reduce_scatter

    def allgather_unpadding_aligned(self,
                                    padded_data: torch.Tensor) -> torch.Tensor:
        padded_data_allgather = tensor_model_parallel_all_gather(
            padded_data, 0)
        if self.padding_flag:
            lengths_sum_unpadding = self.lengths_sum_unpadding
            unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
        else:
            unpadding_data = padded_data_allgather
        return unpadding_data

    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
        padded_data = F.pad(data, (0, 0, 0, self.pad_size))
        start = self.tp_rank_in_group * self.slice_size
        end = start + self.slice_size
        slice_data = padded_data[start:end]

        return slice_data

    def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
        if self.padding_flag:
            pad_size = self.pad_size
            padded_data = F.pad(data, (0, 0, 0, pad_size))
        else:
            padded_data = data
        padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)

        padded_data_reduce_scatter = padded_data[self.tp_rank_in_group]

        return padded_data_reduce_scatter


def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
    if not enable_sequence_parallelism:
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    is_prefill = 0
    attn_metadata = get_forward_context().attn_metadata
    tp_size = get_tensor_model_parallel_world_size()
    if attn_metadata is not None:
        if hasattr(attn_metadata,
                   'is_only_prefill') and attn_metadata.is_only_prefill:
            is_prefill = 1
        if hasattr(attn_metadata,
                   'num_prefills') and attn_metadata.num_prefills > 0:
            is_prefill = 1

        if is_prefill:
            lengths_sum_unpadding = input_ids.shape[0]
            lengths_sum_padding = (
                (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
            if lengths_sum_unpadding == lengths_sum_padding:
                padding_flag = False
            else:
                padding_flag = True
            pad_size = lengths_sum_padding - lengths_sum_unpadding
            _metadata_for_padding = MetadataForPadding(
                lengths_sum_unpadding=lengths_sum_unpadding,
                lengths_sum_padding=lengths_sum_padding,
                padding_flag=padding_flag,
                pad_size=pad_size,
                not_dummy_and_is_prefill=True)

            return _metadata_for_padding

    return MetadataForPadding(padding_flag=False,
                              not_dummy_and_is_prefill=False)
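init_metadata_for_sp pads the flattened token dimension up to a multiple of the tensor-parallel world size so that reduce-scatter and all-gather split evenly across ranks. A small worked sketch of that arithmetic, assuming tp_size = 4 and 10 prefill tokens (the numbers are illustrative, not from the commit):

# Illustrative arithmetic only; no distributed setup required.
tp_size = 4
lengths_sum_unpadding = 10  # tokens in the batch
lengths_sum_padding = ((lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
pad_size = lengths_sum_padding - lengths_sum_unpadding
slice_size = lengths_sum_padding // tp_size
assert (lengths_sum_padding, pad_size, slice_size) == (12, 2, 3)
# Each TP rank then works on slice_size (= 3) rows after padding_slice /
# padding_aligned_reduce_scatter, and allgather_unpadding_aligned drops the
# trailing pad_size (= 2) rows when gathering the full sequence back.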
245
vllm_npu/torchair/ops/shared_weight_layer.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
|
||||
def dispose_tensor(x: torch.Tensor):
|
||||
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerMetadata:
|
||||
"""Metadata for a layer.
|
||||
"""
|
||||
layer: Optional[LinearBase] # The layer object.
|
||||
post_method: Callable[[
|
||||
torch.nn.Module
|
||||
], None] # The `process_weights_after_loading` method from the quant method.
|
||||
weight: torch.Tensor # The weight tensor.
|
||||
window_idx: int # The index of the window.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedWindowMetadata:
|
||||
"""Metadata for a shared window.
|
||||
"""
|
||||
weight: torch.Tensor # The weight tensor to be shared by layers.
|
||||
data_layer_idx: int # The index of the layer this window's weight is equal to.
|
||||
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeriesMetadata:
|
||||
"""Metadata for a weight shared series.
|
||||
"""
|
||||
group: GroupCoordinator
|
||||
start_layer: int
|
||||
end_layer: int
|
||||
num_layers: int
|
||||
prefetch_step: int
|
||||
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
|
||||
layers: list[LayerMetadata]
|
||||
shared_windows: list[
|
||||
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
|
||||
window_offset: int # The index of the window for the next coming layer.
|
||||
|
||||
def is_source(self, layer_idx) -> bool:
|
||||
return layer_idx % self.group.world_size == self.group.rank_in_group
|
||||
|
||||
def post_process_after_loading(self):
|
||||
# This method only needs to be called once per series.
|
||||
if self.shared_windows:
|
||||
return
|
||||
for layer_idx in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[layer_idx - self.start_layer]
|
||||
is_source = self.is_source(layer_idx)
|
||||
# If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight.
|
||||
if not is_source:
|
||||
layer.weight.set_(torch.empty_like(self.dummy_weight))
|
||||
# Broadcast to get the true weight.
|
||||
dist.broadcast(layer.weight,
|
||||
src=self.group.ranks[layer_idx %
|
||||
self.group.world_size],
|
||||
group=self.group.device_group)
|
||||
assert layer.layer is not None
|
||||
# Call `process_weights_after_loading` from the quant method.
|
||||
layer.post_method(layer.layer)
|
||||
step = layer_idx - self.start_layer
|
||||
if step < self.prefetch_step:
|
||||
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=layer.weight.clone().detach(),
|
||||
data_layer_idx=layer_idx,
|
||||
work=None,
|
||||
))
|
||||
layer.window_idx = step
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not is_source:
|
||||
layer.weight.set_(self.shared_windows[-1].weight)
|
||||
else:
|
||||
# Build one more window for prefetch. The weight is useless, so just keep the shape.
|
||||
if step == self.prefetch_step:
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=torch.empty_like(layer.weight),
|
||||
data_layer_idx=-1,
|
||||
work=None,
|
||||
))
|
||||
# When the layer not intended to be stored in this device, dispose the tensor.
|
||||
if not is_source:
|
||||
dispose_tensor(layer.weight)
|
||||
|
||||
dispose_tensor(self.dummy_weight)
|
||||
|
||||
def reach_layer(self, layer_idx: int):
|
||||
# The index of the layer to be prefetched.
|
||||
next_layer_idx = (layer_idx + self.prefetch_step
|
||||
) % self.num_layers + self.start_layer
|
||||
next_layer = self.layers[next_layer_idx - self.start_layer]
|
||||
# The index of the window to store the weight for the coming layer.
|
||||
next_layer.window_idx = self.window_offset
|
||||
window = self.shared_windows[next_layer.window_idx]
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not self.is_source(next_layer_idx):
|
||||
next_layer.weight.set_(window.weight)
|
||||
# Update `window_offset` by rolling one step.
|
||||
self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
|
||||
1)
|
||||
assert window.data_layer_idx != next_layer_idx
|
||||
window.data_layer_idx = next_layer_idx
|
||||
# Start asynchronous broadcast work.
|
||||
window.work = dist.broadcast(
|
||||
next_layer.weight,
|
||||
src=self.group.ranks[next_layer_idx % self.group.world_size],
|
||||
group=self.group.device_group,
|
||||
async_op=True)
|
||||
|
||||
def wait_weight(self, layer_idx: int):
|
||||
# Find the asynchronous broadcast work and wait for it.
|
||||
assert self.shared_windows
|
||||
window = self.shared_windows[self.layers[layer_idx -
|
||||
self.start_layer].window_idx]
|
||||
# Make sure the data in the corresponding shared window is for the current layer.
|
||||
assert window.data_layer_idx == layer_idx
|
||||
if window.work is not None:
|
||||
window.work.wait()
|
||||
window.work = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerExternalMetadata:
|
||||
"""External metadata for a layer.
|
||||
"""
|
||||
series: SeriesMetadata
|
||||
layer_idx: int
|
||||
|
||||
|
||||
_series_dict: dict[str, SeriesMetadata] = {}
|
||||
|
||||
_layer_external_dict: dict[int, LayerExternalMetadata] = {}
|
||||
|
||||
|
||||
def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
|
||||
layer_idx: int) -> Callable:
|
||||
|
||||
def wrapped_forward(*args, **kwargs):
|
||||
# Wait for the weight.
|
||||
series.wait_weight(layer_idx)
|
||||
return forward(*args, **kwargs)
|
||||
|
||||
return wrapped_forward
|
||||
|
||||
|
||||
"""
|
||||
Register linear layers into a shared storage series.
|
||||
|
||||
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
|
||||
|
||||
After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
|
||||
|
||||
During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
|
||||
|
||||
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
|
||||
- total_layers = end_layer - start_layer
|
||||
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
|
||||
|
||||
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
|
||||
|
||||
Arguments:
|
||||
series_name: This name identifies which series this layer belongs to.
|
||||
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
|
||||
start_layer: The index of the first layer in the series (inclusive).
|
||||
end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
|
||||
layer_idx: The index of the current layer.
|
||||
layer: The linear layer object to register.
|
||||
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
|
||||
"""
|
||||
|
||||
|
||||
def register_layer_to_shared_weight_series(
|
||||
series_name: str,
|
||||
group: GroupCoordinator,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
layer_idx: int,
|
||||
layer: LinearBase,
|
||||
prefetch_step: int = 1,
|
||||
):
|
||||
global _series_dict
|
||||
if series_name not in _series_dict:
|
||||
num_layers = end_layer - start_layer
|
||||
assert num_layers > 0
|
||||
assert prefetch_step >= 0 and prefetch_step <= num_layers - 2
|
||||
_series_dict[series_name] = SeriesMetadata(
|
||||
group=group,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
num_layers=num_layers,
|
||||
prefetch_step=prefetch_step,
|
||||
dummy_weight=torch.empty_like(layer.weight),
|
||||
layers=[
|
||||
LayerMetadata(
|
||||
layer=None,
|
||||
post_method=lambda layer: None,
|
||||
weight=torch.empty([]),
|
||||
window_idx=-1,
|
||||
) for _ in range(num_layers)
|
||||
],
|
||||
shared_windows=[],
|
||||
window_offset=prefetch_step,
|
||||
)
|
||||
series = _series_dict[series_name]
|
||||
assert layer.quant_method is not None
|
||||
series.layers[layer_idx - start_layer] = LayerMetadata(
|
||||
layer=layer,
|
||||
post_method=layer.quant_method.process_weights_after_loading,
|
||||
weight=layer.weight,
|
||||
window_idx=-1,
|
||||
)
|
||||
# Discard the original `process_weights_after_loading` method such that it won't be called by others.
|
||||
layer.quant_method.process_weights_after_loading = lambda layer: None
|
||||
# When the layer is not intended to be stored on this device, dispose of the tensor and skip weight loading.
|
||||
if not series.is_source(layer_idx):
|
||||
dispose_tensor(layer.weight)
|
||||
layer.weight.weight_loader = lambda *args, **kwargs: None
|
||||
layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
|
||||
global _layer_external_dict
|
||||
_layer_external_dict[id(layer)] = LayerExternalMetadata(
|
||||
series=series,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.post_process_after_loading()
|
||||
|
||||
|
||||
def reach_layer_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.reach_layer(ext.layer_idx)
|
||||
37
vllm_npu/torchair/ops/torchair_activation.py
Normal file
@@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""AscendSiluAndMul forward in torchair mode.
|
||||
|
||||
The key difference from the original implementation is the removal of operators
|
||||
from the torch.ops.vllm class, as these operators only function in non-torchair
|
||||
modes. Adding them back would cause the graph compilation to fail.
|
||||
"""
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_npu.utils import is_310p
|
||||
|
||||
if is_310p():
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
return out
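# Reference sketch: a plain-PyTorch version of what the fused `npu_swiglu` call above is
# assumed to compute (the usual SiluAndMul semantics: split the last dim in half, apply
# SiLU to the first half, and multiply elementwise by the second half).
def _reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up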
|
||||
1409
vllm_npu/torchair/ops/torchair_fused_moe.py
Normal file
File diff suppressed because it is too large
78
vllm_npu/torchair/ops/torchair_layernorm.py
Normal file
@@ -0,0 +1,78 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
_original_re_init = RMSNorm.__init__
|
||||
|
||||
|
||||
def torchair_rmsnorm_init_(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
_original_re_init(self, hidden_size, eps, var_hidden_size, has_weight,
|
||||
dtype)
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.bias = None
|
||||
# quantization with anti_method m4 generates a non-zero norm bias
|
||||
if vllm_config.quant_config is not None and \
|
||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def torchair_rmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""AscendRMSNorm forward in torchair mode.
|
||||
|
||||
The key difference from the original implementation is the removal of operators
|
||||
from the torch.ops.vllm class, as these operators only function in non-torchair
|
||||
modes. Adding them back would cause the graph compilation to fail.
|
||||
"""
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_npu.utils import is_310p
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
if self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
if self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x
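# Reference sketch: an eager-mode version of what the fused NPU kernels above are assumed
# to compute. RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight, with an optional residual
# added first (and returned) and an optional bias added to the normalized output.
def _reference_rms_norm(x, weight, eps, residual=None, bias=None):
    if residual is not None:
        x = x + residual
        residual = x
    out = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    if bias is not None:
        out = out + bias
    return out if residual is None else (out, residual)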
|
||||
365
vllm_npu/torchair/ops/torchair_rotary_embedding.py
Normal file
@@ -0,0 +1,365 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.utils import enable_custom_op, is_310p
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if get_ascend_config(
|
||||
).torchair_graph_config.enabled and not is_qwen_torchair:
|
||||
return self.forward_native(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
)
|
||||
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def native_rope_deepseek_forward(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non-neox-style method by shuffling the last dim and then
|
||||
# applying the neox-style calculation, which is also more compute-friendly on Ascend hardware. See:
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_q, d)
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
return (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
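# Worked example: yarn_get_mscale(scale=40, mscale=1.0) = 0.1 * 1.0 * ln(40) + 1.0 ≈ 1.369,
# i.e. attention magnitudes are scaled up slightly when the context is extended.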
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
low = torch.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = torch.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
# Note: use torch instead of max/min to solve MTP compilation error.
|
||||
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_value, max_value, dim):
|
||||
# Note: The if conditional branch is not used here
|
||||
# to solve MTP compilation error.
|
||||
max_value += (min_value == max_value).float() * 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) -
|
||||
min_value) / (max_value - min_value)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
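# Illustrative sketch of how the three YaRN helpers above combine (the parameter values
# here are hypothetical, not taken from any particular model config). Dimensions below
# `low` keep the extrapolated (original-base) frequencies, dimensions above `high` use
# the interpolated (scaled) ones, and the ramp mask blends the two in between.
def _example_yarn_blend(dim=64, base=10000, max_pos=4096, beta_fast=32, beta_slow=1):
    low, high = yarn_find_correction_range(beta_fast, beta_slow, dim, base, max_pos)
    extrapolation_weight = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
    return low, high, extrapolation_weight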
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if len(q.shape) == 3:
|
||||
q = q[:, :, None, :]
|
||||
if len(k.shape) == 2:
|
||||
k = k[:, None, None, :]
|
||||
elif len(k.shape) == 3:
|
||||
k = k[:, :, None, :]
|
||||
|
||||
b, h_q, s, d = q.shape
|
||||
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
|
||||
|
||||
b, h_k, s, d = k.shape
|
||||
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
q_embed = q_embed.view(b, h_q, d)
|
||||
k_embed = k_embed.view(b, h_k, d)
|
||||
|
||||
return q_embed, k_embed
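# Tiny reference of the de-interleave used above: a head dimension laid out as
# [x0, x1, x2, x3] (GPT-J interleaved style) is reordered to [x0, x2, x1, x3]
# (GPT-NeoX half-split style) before the rotation is applied:
# >>> t = torch.arange(4).view(1, 1, 1, 4)
# >>> t.view(1, 1, 1, 2, 2).transpose(4, 3).reshape(1, 1, 1, 4)
# tensor([[[[0, 2, 1, 3]]]])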
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
||||
device=device, dtype=torch.float32)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
|
||||
cos_cached = cos_cached.to(dtype)
|
||||
sin_cached = sin_cached.to(dtype)
|
||||
cache = torch.cat([freqs.cos() * self.mscale,
|
||||
freqs.sin() * self.mscale],
|
||||
dim=-1).to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
|
||||
|
||||
def __set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
inv_freq = 1.0 / (self.base**(torch.arange(
|
||||
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
|
||||
(1 / self.rotary_dim)))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
|
||||
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
|
||||
self.embed = F.embedding
|
||||
|
||||
|
||||
_original_re_init = RotaryEmbedding.__init__
|
||||
|
||||
|
||||
def qwen_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
if get_ascend_config().torchair_graph_config.enabled:
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_position_embeddings,
|
||||
device="npu",
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def rope_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
max_seq_len: Optional[int] = None,
|
||||
is_prefill: Optional[bool] = True,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
):
|
||||
if get_ascend_config().torchair_graph_config.enabled \
|
||||
and is_qwen_torchair and not is_prefill:
|
||||
if max_seq_len is not None and torch.gt(max_seq_len,
|
||||
self.max_position_embeddings):
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_seq_len,
|
||||
device=query.device,
|
||||
dtype=torch.float32)
|
||||
|
||||
# bsnd/bnsd
|
||||
if positions is not None:
|
||||
cos = self.embed(positions, self.cos)
|
||||
sin = self.embed(positions, self.sin)
|
||||
self.cos_embed = cos
|
||||
self.sin_embed = sin
|
||||
else:
|
||||
cos = self.cos_embed
|
||||
sin = self.sin_embed
|
||||
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
|
||||
|
||||
cos = cos.unsqueeze(-2).unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2).unsqueeze(-2)
|
||||
|
||||
query = query.unsqueeze(1)
|
||||
key = key.unsqueeze(1)
|
||||
|
||||
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
|
||||
query, key, cos, sin)
|
||||
return q_embed.flatten(-2), k_embed.flatten(-2)
|
||||
else:
|
||||
return rope_forward_oot(self, positions, query, key, offsets,
|
||||
is_neox_style_override,
|
||||
is_qwen_torchair) # type: ignore
|
||||
|
||||
|
||||
def deepseek_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||
_set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
|
||||
38
vllm_npu/torchair/ops/torchair_vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
|
||||
|
||||
def vocab_embedding_forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = self._get_masked_input_and_mask(
|
||||
input_, self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
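# Illustrative sketch of the masking idea above in plain PyTorch (the shard bounds here
# are hypothetical, and added-vocab handling is omitted): token ids outside this rank's
# vocab shard are redirected to index 0 for the local lookup and their rows zeroed, so
# the subsequent all-reduce sums exactly one real embedding per token.
def _example_masked_lookup(input_ids, local_table, vocab_start, vocab_end):
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = (input_ids - vocab_start).masked_fill(mask, 0)
    out = local_table[local_ids]
    out = out.masked_fill(mask.unsqueeze(-1), 0)
    return out  # summed across ranks by tensor_model_parallel_all_reduce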
|
||||
0
vllm_npu/torchair/quantization/__init__.py
Normal file
486
vllm_npu/torchair/quantization/torchair_w4a8_dynamic.py
Normal file
@@ -0,0 +1,486 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.ascend_forward_context import FusedMoEState
|
||||
from vllm_npu.distributed.parallel_state import get_mc2_group
|
||||
from vllm_npu.torchair.quantization.torchair_w8a8_dynamic import (
|
||||
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
|
||||
from vllm_npu.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W4A8_DYNAMIC
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
|
||||
if self.new_quant_version:
|
||||
pack_factor = 2
|
||||
actual_output_size = output_size // pack_factor
|
||||
params_dict["weight"] = torch.empty(actual_output_size,
|
||||
input_size,
|
||||
dtype=torch.int8)
|
||||
params_dict["_packed_dim"] = 0
|
||||
params_dict["_packed_factor"] = pack_factor
|
||||
else:
|
||||
params_dict["weight"] = torch.empty(output_size,
|
||||
input_size,
|
||||
dtype=torch.int8)
|
||||
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_scale_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
if self.new_quant_version:
|
||||
scale_bias_dim = 16 if layer_type == "row" else 1
|
||||
params_dict["scale_bias"] = torch.empty(output_size,
|
||||
scale_bias_dim,
|
||||
dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def process_scale_second(weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
per_group_scale: torch.Tensor,
|
||||
is_new_quant: bool = False):
|
||||
k, n = weight.shape
|
||||
group_num, n_scale = per_group_scale.shape
|
||||
|
||||
if is_new_quant:
|
||||
n = n * 2
|
||||
|
||||
bias = None
|
||||
if not is_new_quant:
|
||||
weight_high = weight.to(torch.float32).reshape(
|
||||
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
|
||||
weight_high = weight_high.reshape(k, n)
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
|
||||
|
||||
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
|
||||
return antiquant_scale.npu(), bias
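# Note on the relation encoded above (inferred from this code, not a documented fact):
# the antiquant scale implements a two-level dequantization, roughly
#   w_fp[k, n] ≈ w_int4[k, n] * per_group_scale[k // group_size, n] * scale[n],
# and the `bias` term (old format only) presumably compensates for a zero-point of 8
# in the int4 representation.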
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch_npu.npu_weight_quant_batchmatmul(
|
||||
x,
|
||||
layer.weight,
|
||||
antiquant_scale=layer.weight_scale_second.to(x.dtype),
|
||||
antiquant_group_size=self.group_size,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
|
||||
torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
|
||||
layer.weight.data,
|
||||
layer.weight_scale.data,
|
||||
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
|
||||
is_new_quant=self.new_quant_version,
|
||||
)
|
||||
|
||||
if self.new_quant_version:
|
||||
if hasattr(layer, "scale_bias"):
|
||||
if layer.scale_bias.data.shape[1] == 1:
|
||||
layer.scale_bias.data = layer.scale_bias.data.flatten()
|
||||
else:
|
||||
layer.scale_bias.data = layer.scale_bias.data.contiguous()
|
||||
else:
|
||||
if scale_bias is not None:
|
||||
param = torch.nn.Parameter(scale_bias, requires_grad=False)
|
||||
layer.register_parameter("weight_scale_bias", param)
|
||||
|
||||
if self.new_quant_version:
|
||||
assert layer.weight.data.shape[-1] % 4 == 0, \
|
||||
f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
|
||||
layer.weight.data = layer.weight.data.view(
|
||||
torch.int32).contiguous()
|
||||
else:
|
||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
||||
layer.weight.data.to(torch.int32))
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W4A8_DYNAMIC.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
# NOTE: group_size == 0 means the weights were quantized from bf16 to int4 per-channel (no per-group scales)
|
||||
self.is_per_channel_weight = self.group_size == 0
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: new quantization format: two int4 values are packed into one int8
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
|
||||
if self.new_quant_version and self.tp_size > 16:
|
||||
raise ValueError(
|
||||
"The current weight does not support moe part tp>16.")
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = ""
|
||||
|
||||
def get_weight(self, num_experts: int,
|
||||
intermediate_size_per_partition: int, hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
if self.new_quant_version:
|
||||
w13_output_size = intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes // 2
|
||||
else:
|
||||
w13_output_size = 2 * intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes
|
||||
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
w13_output_size,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
w2_output_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8)
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
if not self.is_per_channel_weight:
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_scale_bias"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
16 // self.tp_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk currently is 8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=
|
||||
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(quantized_x_for_share, router_logits)
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
# This is a naive implementation of expert load balancing, intended to
|
||||
# avoid accumulating too many tokens on a single rank.
|
||||
# Currently it is only activated during profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return torchair_fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
is_torchair=self.torchair_graph_enabled,
|
||||
quantized_x_for_share=shared_gate_up,
|
||||
dynamic_scale_for_share=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None),
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
else:
|
||||
# The current implementation of deepseek moe splits hidden_states
|
||||
# according to tp_size before they are fed into the layers module.
|
||||
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||
# dispatch/combine tokens.
|
||||
return torchair_fused_experts_with_all2all(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
scale = scale.transpose(1, 2).contiguous()
|
||||
if self.is_per_channel_weight:
|
||||
scale_np = scale.cpu().numpy()
|
||||
scale_np.dtype = np.uint32
|
||||
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
||||
np.int64)).npu()
|
||||
return scale_uint64_tensor, None
|
||||
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
|
||||
group_num, k, n = weight.shape
|
||||
# In the new format, packing halves n in the stored weight, so restore the full n here.
|
||||
if self.new_quant_version:
|
||||
n = n * 2
|
||||
per_group_scale = per_group_scale.reshape(group_num, -1, n)
|
||||
group_num, quantgroup_num, n = per_group_scale.shape
|
||||
bias = None
|
||||
if not self.new_quant_version:
|
||||
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
||||
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
|
||||
weight_high = weight_high.reshape([group_num, k, n])
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
|
||||
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
|
||||
torch.float32)
|
||||
scale_fp32_np = scale_fp32.cpu().numpy()
|
||||
scale_fp32_np.dtype = np.uint32
|
||||
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
|
||||
dtype=np.uint32)
|
||||
|
||||
sscale_uint64[..., ::2] = scale_fp32_np
|
||||
|
||||
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
|
||||
dtype=np.int64).copy()
|
||||
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
|
||||
group_num, quantgroup_num, n)
|
||||
sscale_uint64_tensor = sscale_uint64_tensor.npu()
|
||||
return sscale_uint64_tensor, bias
|
||||
|
||||
def update_bias(self, layer, w13_bias, w2_bias):
|
||||
if self.new_quant_version:
|
||||
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
else:
|
||||
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
||||
layer.register_parameter("w13_scale_bias", w13_scale_bias)
|
||||
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
||||
layer.register_parameter("w2_scale_bias", w2_scale_bias)
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
# Pack 4 int8 values (each holding two int4) into one int32, because in PyTorch we need to use int32 to represent int4.
|
||||
assert weight.shape[
|
||||
-1] % 4 == 0, "the last dim of weight must be divisible by 4"
|
||||
return weight.view(torch.int32).contiguous()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
torch.quint4x2, -1, False)
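# Reference for the new-format branch above: four consecutive int8 bytes (each already
# holding two packed int4 values) are reinterpreted as one int32 without copying, so the
# last dimension shrinks by 4x:
# >>> torch.zeros(2, 8, dtype=torch.int8).view(torch.int32).shape
# torch.Size([2, 2])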
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
|
||||
layer, "w13_weight_scale_second") else None
|
||||
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
|
||||
layer, "w2_weight_scale_second") else None
|
||||
layer.w13_weight_scale.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
w13_weight_scale_second)
|
||||
layer.w2_weight_scale.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
w2_weight_scale_second)
|
||||
if hasattr(layer, "w13_weight_scale_second"):
|
||||
# scale_second is no longer used, release this part of the memory
|
||||
del layer.w13_weight_scale_second
|
||||
del layer.w2_weight_scale_second
|
||||
del layer.w13_weight_offset_second
|
||||
del layer.w2_weight_offset_second
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
||||
1064
vllm_npu/torchair/quantization/torchair_w8a8_dynamic.py
Normal file
File diff suppressed because it is too large
463
vllm_npu/torchair/torchair_attention.py
Normal file
@@ -0,0 +1,463 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv
|
||||
|
||||
from vllm_npu.attention.attention_v1 import (AscendAttentionBackend,
|
||||
AscendAttentionMetadataBuilder,
|
||||
AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_npu.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_npu.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_npu.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d)
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackend(AscendAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND_TORCHAIR"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
|
||||
return AscendAttentionTorchairBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
|
||||
return AscendTorchairMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
|
||||
return AscendAttentionTorchairMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendDecodeMetadata:
|
||||
# Input positions for rotary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendTorchairMetadata(AscendMetadata):
|
||||
|
||||
decode: Optional[AscendDecodeMetadata] = None
|
||||
|
||||
|
||||
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
self.max_num_blocks_per_req = cdiv(
|
||||
self.model_config.max_model_len,
|
||||
self.vllm_config.cache_config.block_size)
|
||||
self.max_blocks = (self.model_config.max_model_len +
|
||||
self.vllm_config.cache_config.block_size -
|
||||
1) // self.vllm_config.cache_config.block_size
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||
max_blocks = self.max_blocks
|
||||
|
||||
graph_block_tables = torch.zeros((num_seqs, max_blocks),
|
||||
dtype=block_tables.dtype,
|
||||
device=block_tables.device)
|
||||
|
||||
num_blocks = block_tables.size(1)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[:num_seqs, :
|
||||
num_blocks] = block_tables[:num_seqs, :
|
||||
num_blocks]
|
||||
else:
|
||||
graph_block_tables[:num_seqs, :
|
||||
max_blocks] = block_tables[:num_seqs, :
|
||||
max_blocks]
|
||||
|
||||
return graph_block_tables[:, :max_blocks]
|
||||
|
||||
def build_torchair_graph_dummy(
|
||||
self, common_attn_metadata: TorchairCommonAttentionMetadata
|
||||
) -> AscendTorchairMetadata:
|
||||
device = self.device
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
block_table = torch.zeros((num_reqs, self.max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_reqs, block_table)
|
||||
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
||||
input_positions = torch.zeros(num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device).long()
|
||||
slot_mapping = torch.full((num_reqs, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
query_start_loc = torch.full((num_reqs, ),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=1)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_lens=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
decode=decode_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
|
||||
block_table[:num_reqs])
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
use_torchair_graph = graph_pad_size > -1
|
||||
if common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
num_seqs = len(seq_lens)
|
||||
if use_torchair_graph and common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
num_reqs_pad_size = 0
|
||||
num_token_pad_size = 0
|
||||
if graph_pad_size != 0:
|
||||
pad_value = 0
|
||||
num_token_pad_size = graph_pad_size - num_actual_tokens
|
||||
num_reqs_pad_size = (
|
||||
graph_pad_size //
|
||||
common_attn_metadata.decode_token_per_req - num_reqs)
|
||||
pad_value = 1
|
||||
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||
] * num_reqs_pad_size
|
||||
|
||||
seq_lens = torch.from_numpy(
|
||||
np.array(padded_seq_lens).astype(np.int32))
|
||||
padding = torch.full((num_token_pad_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=slot_mapping.dtype,
|
||||
device=slot_mapping.device)
|
||||
slot_mapping = torch.cat([slot_mapping, padding])
|
||||
block_table_padding = torch.zeros(
|
||||
(num_reqs_pad_size, ) + block_table.shape[1:],
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device)
|
||||
block_table = torch.cat([block_table, block_table_padding],
|
||||
dim=0)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_seqs + num_reqs_pad_size, block_table)
|
||||
padding_0 = torch.zeros(num_token_pad_size,
|
||||
dtype=input_positions.dtype,
|
||||
device=input_positions.device)
|
||||
input_positions = torch.cat([input_positions, padding_0])
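# Worked example (illustrative numbers): if graph_pad_size = 8, num_actual_tokens = 6,
# decode_token_per_req = 1 and num_reqs = 6, then num_token_pad_size = 8 - 6 = 2 and
# num_reqs_pad_size = 8 // 1 - 6 = 2, i.e. two dummy tokens and two dummy requests are
# appended so the captured torchair graph always sees fixed-size inputs.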
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=None)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
decode=decode_metadata,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.hidden_size = self.num_heads * self.head_size
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes,
|
||||
dtype=torch.float32,
|
||||
device="npu")
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.attn_type = attn_type
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
self.scale_tensor = torch.zeros((), device='npu', dtype=torch.int32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AscendTorchairMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
trace_flag: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Ascend attention.
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache: shape = [2, num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
key_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
value_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size * seq_len, num_heads, head_size]
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0
|
||||
and kv_cache[0].numel() > 0
|
||||
and kv_cache[0].dtype == torch.int8)
|
||||
if output is None:
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_quant:
|
||||
output = layer.quant_method.apply(layer, query, key, value,
|
||||
kv_cache, attn_metadata,
|
||||
self.attn_type, self.scale,
|
||||
output)
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, self.hidden_size).fill_(0)
|
||||
|
||||
output = output.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"AscendAttentionTorchairBackendImpl")
|
||||
|
||||
if kv_cache is not None and kv_cache[0].numel() > 0:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
|
||||
block_size = self.scale_tensor + key_cache.shape[1]
|
||||
slots_indices = slots.reshape(-1, 1)
|
||||
block_indices = slots_indices // block_size
|
||||
slots_indices = slots_indices % block_size
|
||||
indices = torch.cat((block_indices, slots_indices), dim=1)
|
||||
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
|
||||
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
self.key_cache = key_cache
|
||||
self.value_cache = value_cache
|
||||
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
|
||||
# View q k v to BSH.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if is_310p():
|
||||
# align q k v output tensors
|
||||
query = aligned_16(query)
|
||||
key = aligned_16(key)
|
||||
value = aligned_16(value)
|
||||
output = aligned_16(output)
|
||||
|
||||
# do reformat in case of broadcasted tensors
|
||||
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
output = output[:num_tokens, :, :]
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
compress_mask = attn_metadata.attn_mask
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_table,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
seq_lens = decode_meta.seq_lens_list
|
||||
block_table = decode_meta.block_table
|
||||
block_size = key_cache.shape[1]
|
||||
query = query.view(num_tokens, 1,
|
||||
self.num_heads * self.head_size).contiguous()
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
key=key_cache,
|
||||
value=value_cache,
|
||||
query_rope=None,
|
||||
key_rope=None,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout='BSH',
|
||||
atten_mask=decode_meta.attn_mask,
|
||||
sparse_mode=0,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=seq_lens,
|
||||
)
|
||||
        else:
            raise NotImplementedError(
                "Torchair graph mode with non-MLA attention backend is still "
                "experimental. The v1 scheduler (chunked prefill) is not "
                "supported at the moment. Please set "
                "'ascend_scheduler_config': {'enabled': true} in "
                "additional_config to use the ascend scheduler.")

        return output.view(num_tokens, self.hidden_size)
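Aside (illustrative sketch, not part of the diff): the scatter-update path above turns each flat slot id from `slot_mapping` into a `(block, offset)` pair before writing K/V into the paged cache. The index arithmetic can be checked on the CPU with made-up values:

import torch

block_size = 128                                 # key_cache.shape[1] in the code above
slots = torch.tensor([0, 1, 130, 259])           # flat slot ids, as in slot_mapping
slots_indices = slots.reshape(-1, 1)
block_indices = slots_indices // block_size      # which cache block each token lands in
offsets = slots_indices % block_size             # position inside that block
indices = torch.cat((block_indices, offsets), dim=1)
print(indices.tolist())                          # [[0, 0], [0, 1], [1, 2], [2, 3]]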
1310  vllm_npu/torchair/torchair_mla.py  Normal file
File diff suppressed because it is too large

557  vllm_npu/torchair/torchair_model_runner.py  Normal file
@@ -0,0 +1,557 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file

import math
import types
from typing import Any, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger

import vllm_npu.envs as envs_ascend
from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.platform import NPUPlatform
from vllm_npu.torchair.utils import (
    TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
    check_torchair_cache_exist, converting_weight_acl_format,
    register_torchair_model, torchair_ops_patch,
    torchair_quant_method_register, write_kv_cache_bytes_to_file)
from vllm_npu.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                            is_310p, get_ascend_soc_version,
                            AscendSocVersion)
from vllm_npu.worker.model_runner_v1 import NPUModelRunner


class NPUTorchairModelRunner(NPUModelRunner):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.ascend_config = get_ascend_config()
        self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
        super().__init__(vllm_config, device)
        if self.speculative_config:
            self.actual_seq_lengths_q = list(
                range(self.decode_token_per_req, self.max_num_tokens + 1,
                      self.decode_token_per_req))
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            None, None, vllm_config, device)

        register_torchair_model()
        torchair_ops_patch()
        torchair_quant_method_register()
        if self.enable_shared_expert_dp:
            return
        self.new_kv_cache_bytes = -1
        self.torchair_compiled_model = None  # type: ignore
        self.torchair_compiled_models = {}  # type: ignore
        self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
        self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
        self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
        if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
            self.init_torchair_graph_batch_sizes()

        self.update_torchair_graph_batch_sizes()

        torch._dynamo.cache_size.config.cache_size_limit += len(
            self.torchair_graph_batch_sizes)
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        torch._logging.set_logs(
            recompiles=envs_ascend.vllm_npu_TRACE_RECOMPILES)

        self._check_batch_sizes_consistency()

    def _may_pad_kv_consumer_num_seq(self):
        # The PD-disaggregation scenario needs redundant batch sizes so that no
        # batch's seq_len exceeds 16 tokens.
        # self.max_num_reqs here is greater than the actual maximum request number.
        if self.decode_token_per_req > 1 and self.is_kv_consumer:
            # applied only when speculative decoding is active
            FIA_SEQ_LEN_LIMIT = 16
            new_max_num_reqs = self.max_num_reqs + math.ceil(
                self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
                    (self.max_num_reqs * self.decode_token_per_req) /
                    (FIA_SEQ_LEN_LIMIT**2))
            if self.max_num_reqs < new_max_num_reqs:
                logger.warning(
                    f"max_num_reqs is updated to {new_max_num_reqs}")
                self.max_num_reqs = new_max_num_reqs

    def _init_mc2_tokens_capacity(self):
        # NOTE: To be clear, we need to make sure that during graph capture, the number of
        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
        tp_size = self.parallel_config.tensor_parallel_size
        max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
            max_num_tokens, tp_size)
        self.mc2_tokens_capacity = max_graph_batch_size

        if get_ascend_soc_version(
        ) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
            logger.error(
                f"A3: the max number of tokens must be smaller than 512, but it is now {self.mc2_tokens_capacity}"
            )
        if get_ascend_soc_version(
        ) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
            logger.error(
                f"A2: the max number of tokens must be smaller than 256, but it is now {self.mc2_tokens_capacity}"
            )

    def _sync_metadata_across_dp(
        self, num_tokens: int, with_prefill: bool, enable_dbo: bool
    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
        """Override from NPUModelRunner to pad num_tokens"""
        if self.enable_shared_expert_dp:
            # Padding is not required for shared_expert_dp cases in eager mode.
            return num_tokens, None, with_prefill, enable_dbo
        if self.dp_size == 1:
            if not with_prefill:
                maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
                    num_tokens)
                return maybe_padded_num_tokens, None, with_prefill, enable_dbo
            return num_tokens, None, with_prefill, enable_dbo

        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
                                           dtype=torch.int32,
                                           device="npu")
        num_tokens_across_dp[self.dp_rank] = num_tokens
        num_tokens_across_dp[-2] = int(with_prefill)
        num_tokens_across_dp[-1] = int(not enable_dbo)
        dist.all_reduce(num_tokens_across_dp,
                        group=get_dp_group().device_group)
        with_prefill = bool(num_tokens_across_dp[-2])
        enable_dbo = not bool(num_tokens_across_dp[-1])
        num_tokens_across_dp = num_tokens_across_dp[:-2]

        if not with_prefill:
            max_num_token = num_tokens_across_dp.max().item()
            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
                max_num_token)
            num_tokens_across_dp = torch.full((self.dp_size, ),
                                              maybe_padded_num_tokens,
                                              dtype=torch.int32,
                                              device="npu")
        else:
            maybe_padded_num_tokens = num_tokens

        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
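Aside (illustrative sketch, not part of this module): the packing scheme used by _sync_metadata_across_dp can be checked on the CPU with made-up values; one all-reduced tensor carries every rank's token count plus both flags:

import torch

dp_size, dp_rank = 4, 2
num_tokens, with_prefill, enable_dbo = 96, False, True

packed = torch.zeros(dp_size + 2, dtype=torch.int32)
packed[dp_rank] = num_tokens
packed[-2] = int(with_prefill)        # any rank prefilling forces prefill for all
packed[-1] = int(not enable_dbo)      # any rank vetoing DBO disables it everywhere
# dist.all_reduce(packed, group=...) would sum the contributions from every rank here.
with_prefill = bool(packed[-2])
enable_dbo = not bool(packed[-1])
num_tokens_across_dp = packed[:-2]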

    def _build_dummy_attn_metadata(
        self,
        with_prefill: bool,
        num_reqs: int,
        num_tokens: int,
        max_query_len: int,
        aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
        force_attention: bool = False,
    ) -> Optional[dict[str, Any]]:
        # NOTE: If torchair graph mode and not with_prefill,
        # we can't skip_attn, it will cause graph recompile.
        if with_prefill or self.enable_shared_expert_dp:
            attn_metadata = super()._build_dummy_attn_metadata(
                with_prefill, num_reqs, num_tokens, max_query_len,
                aclgraph_runtime_mode, force_attention)
        else:
            common_attn_metadata = TorchairCommonAttentionMetadata(
                num_reqs=num_reqs,
                num_actual_tokens=1,
                actual_seq_lengths_q=self.actual_seq_lengths_q,
                attn_mask=self.attn_mask,
                spec_attn_mask=self.spec_attn_mask,
                decode_token_per_req=self.decode_token_per_req,
            )
            attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
                common_attn_metadata)
        return attn_metadata

    def _generate_dummy_run_hidden_states(self, with_prefill,
                                          is_torchair_compile, input_ids,
                                          positions, attn_metadata, num_tokens,
                                          intermediate_tensors, inputs_embeds):
        if with_prefill or self.enable_shared_expert_dp:
            if is_310p():
                converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
            hidden_states = super()._generate_dummy_run_hidden_states(
                with_prefill, is_torchair_compile, input_ids, positions,
                attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
        else:
            # Only mark static while compiling
            if is_torchair_compile:
                torch._dynamo.mark_static(input_ids)
                torch._dynamo.mark_static(positions)
                torch._dynamo.mark_static(attn_metadata.decode.block_table)
                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
                torch._dynamo.mark_static(get_forward_context().mc2_mask)
                if hasattr(attn_metadata.decode, "sin"):
                    torch._dynamo.mark_static(attn_metadata.decode.sin)
                    torch._dynamo.mark_static(attn_metadata.decode.cos)
                torch._dynamo.mark_static(attn_metadata.slot_mapping)
                if self.speculative_config:
                    torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
                for kv in self.kv_caches:
                    assert isinstance(kv, tuple), "kv_cache must be a tuple"
                    torch._dynamo.mark_static(kv[0])
                    torch._dynamo.mark_static(kv[1])
            if is_310p():
                converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)

            compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
            model_kwargs = {}
            model_kwargs["kv_caches"] = self.kv_caches
            model_kwargs["attn_metadata"] = attn_metadata
            hidden_states = compiled_model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=None,
                **model_kwargs,
            )
        return hidden_states

    def _convert_torch_format(self, kv_cache):
        if self.enable_shared_expert_dp:
            return super()._convert_torch_format(kv_cache)
        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
        return kv_cache

    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
        # Trigger torchair graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
            for _ in range(self.vllm_config.compilation_config.
                           cudagraph_num_of_warmups):
                self._dummy_run(num_tokens, is_torchair_compile=True)
            self._dummy_run(num_tokens, is_torchair_compile=True)
            logger.info("Batchsize %d is compiled successfully: %d/%d.",
                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))

    def _capture_model(self):
        """Override from NPUModelRunner to use torchair graph capture."""
        if self.enable_shared_expert_dp:
            return super()._capture_model()
        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
        # torchair graph capture can cause some issues, so now we just
        # temporarily split the codepath for the two different graph patterns.
        torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
        graph_num = len(torchair_graph_batch_sizes)

        if self.use_cached_npu_graph and not check_torchair_cache_exist():
            # If caching is enabled but does not exist (either
            # use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
            # different), we will compile the model twice. The first time is
            # used to generate the cache, and the second time is used to load the
            # cache to skip the overhead caused by Dynamo guard mechanism.
            logger.info(
                "Cache compilation for torchair graph is enabled. Now we compile the graph to"
                " generate the torchair cache, this usually takes %.1f~%.1f mins.",
                0.5 * graph_num, 1.5 * graph_num)
            self._compile_torchair_graph(torchair_graph_batch_sizes)
            NPUPlatform.synchronize()
            # Note: We reset dynamo and reload the compiled torchair cached computation graph below
            # that was compiled above. This operation reduces graph launch time by 2-4ms and avoids
            # runtime errors caused by configuration mismatches in graph mode.
            torch._dynamo.reset()
            self.torchair_compiled_models.clear()
        if self.use_cached_npu_graph:
            logger.info(
                "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
                0.3 * graph_num, 0.5 * graph_num)
            self._compile_torchair_graph(torchair_graph_batch_sizes)
        else:
            logger.info(
                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
                0.5 * graph_num, 1.5 * graph_num)
            self._compile_torchair_graph(torchair_graph_batch_sizes)

        if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
            write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
                                         self.new_kv_cache_bytes)

    def _use_aclgraph(self) -> bool:
        if self.enable_shared_expert_dp:
            return super()._use_aclgraph()
        return False

    def _check_batch_sizes_consistency(self) -> None:
        if not dist.is_initialized():
            return

        local = torch.tensor(self.torchair_graph_batch_sizes,
                             device="cpu",
                             dtype=torch.int32)
        gathered_graph_batch_size = local.clone()
        dist.all_reduce(gathered_graph_batch_size,
                        group=get_dp_group().cpu_group)
        expected = local * self.dp_size

        if not torch.equal(gathered_graph_batch_size, expected):
            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
                as_tuple=False).flatten().tolist()
            raise AssertionError(
                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
            )

    def _update_graph_pad_size(self, with_prefill, graph_pad_size):
        if with_prefill or self.enable_shared_expert_dp:
            super()._update_graph_pad_size(with_prefill, graph_pad_size)
        else:
            self.graph_pad_size = graph_pad_size

    def _update_input_ids_and_positions(self, input_ids, positions,
                                        num_input_tokens, with_prefill,
                                        padded_num_tokens_across_dp):
        """Override from NPUModelRunner to update input_ids and positions"""
        input_ids, positions = super()._update_input_ids_and_positions(
            input_ids, positions, num_input_tokens, with_prefill,
            padded_num_tokens_across_dp)

        if with_prefill or self.enable_shared_expert_dp:
            return input_ids, positions
        else:
            input_ids = self.input_ids[:padded_num_tokens_across_dp]
            positions = self.positions[:padded_num_tokens_across_dp]
            return input_ids, positions

    def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
                                             padded_num_tokens_across_dp,
                                             input_ids, positions,
                                             intermediate_tensors,
                                             inputs_embeds):
        if attn_metadata is not None and isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata['model.layers.0.self_attn.attn']

        if self.enable_shared_expert_dp:
            return super()._generate_process_reqs_hidden_states(
                attn_metadata, with_prefill, padded_num_tokens_across_dp,
                input_ids, positions, intermediate_tensors, inputs_embeds)
        model_kwargs = {
            "kv_caches": self.kv_caches,
            "attn_metadata": attn_metadata
        }
        if not with_prefill:
            if is_310p():
                converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
            compiled_model = self._get_torchair_lazy_compiled_model(
                padded_num_tokens_across_dp)
            hidden_states = compiled_model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **model_kwargs,
            )
        else:
            assert self.model is not None
            if is_310p():
                converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)

            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **model_kwargs,
            )
        return hidden_states

    def _get_torchair_lazy_compiled_model(self, batch_size: int):
        if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
            raise ValueError(
                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
            )

        compiled_model = self.torchair_compiled_models.get(
            batch_size
        ) if self.use_cached_npu_graph else self.torchair_compiled_model

        if compiled_model:
            return compiled_model

        import torchair  # type: ignore
        from torchair import patch_for_hcom  # type: ignore

        patch_for_hcom()

        if is_310p():
            # on 300I Duo platform, we need to patch broadcast. however, this patch will be
            # overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
            from vllm_npu.patch.platform.patch_distributed import \
                communication_adaptation_310p
            communication_adaptation_310p()

        config = torchair.CompilerConfig()
        if self.ascend_config.torchair_graph_config.mode:
            config.mode = self.ascend_config.torchair_graph_config.mode
        config.experimental_config.frozen_parameter = \
            self.ascend_config.torchair_graph_config.enable_frozen_parameter
        # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
        # disable it on 300I Duo platform now.
        config.experimental_config.tiling_schedule_optimize = not is_310p()
        config.experimental_config.enable_view_optimize = \
            self.ascend_config.torchair_graph_config.enable_view_optimize
        torch.npu.set_compile_mode(jit_compile=False)
        if not self.use_cached_npu_graph:
            npu_backend = torchair.get_npu_backend(compiler_config=config)
            self.torchair_compiled_model = torch.compile(
                self.model,
                dynamic=not self.use_sparse,
                fullgraph=True,
                backend=npu_backend)
            return self.torchair_compiled_model
        else:
            # Generate a new forward proxy code object to prevent the invalidation of
            # compilation cache caused by dynamo retracing
            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
            forward_fn = self.model.forward
            code = forward_fn.__code__
            # Mark code object with a new proxy name
            modified_code = code.replace(co_name=forward_proxy_name, )

            modified_func = types.FunctionType(modified_code,
                                               forward_fn.__globals__,
                                               name=forward_proxy_name,
                                               argdefs=forward_fn.__defaults__)

            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
                self.model, nn.Module)
            self.torchair_compiled_models[
                batch_size] = torchair.inference.cache_compile(
                    self.model.__dict__[forward_proxy_name],
                    dynamic=not self.use_sparse,
                    fullgraph=True,
                    cache_dir=TORCHAIR_CACHE_DIR,
                    config=config,
                    ge_cache=False)
            return self.torchair_compiled_models[batch_size]

    def init_torchair_graph_batch_sizes(self):
        start_graph_batch_size = 4
        tp_size = get_tensor_model_parallel_world_size()

        # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
        start_graph_batch_size = max(start_graph_batch_size, tp_size)

        while (start_graph_batch_size <= self.max_num_reqs):
            self.torchair_graph_batch_sizes.append(start_graph_batch_size)
            start_graph_batch_size *= 2

    def select_torchair_padded_batch_size(self, batch_size: int):
        for padded_batch_size in self.torchair_graph_batch_sizes:
            if batch_size <= padded_batch_size:
                # we treat batch_size as num of requests
                return padded_batch_size
        raise ValueError(
            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
        )

    def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
                                                tp_size):
        cur_graph_batch_size = (old_graph_batch_size + tp_size -
                                1) // tp_size * tp_size
        # MTP > 1: compute the LCM (least common multiple) of graph_batch_size
        # and tp_size, to suit both multi-DP and the FIA operator
        if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
            cur_graph_batch_size = (tp_size * old_graph_batch_size) \
                // math.gcd(tp_size, old_graph_batch_size)
        return cur_graph_batch_size
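Aside (values invented for illustration, not part of this module): the two sizing rules above amount to padding a decode token count up to the next captured graph size, and, with more than one speculative token, rounding a graph size up to the least common multiple of the batch and TP size:

import math

torchair_graph_batch_sizes = [4, 8, 16, 32]
num_tokens = 9
padded = next(s for s in torchair_graph_batch_sizes if num_tokens <= s)
print(padded)                                                            # 16

tp_size, old_graph_batch_size = 4, 6
# default rule: ceil to a multiple of tp_size
rounded = (old_graph_batch_size + tp_size - 1) // tp_size * tp_size       # 8
# rule used when num_speculative_tokens > 1: least common multiple
lcm = tp_size * old_graph_batch_size // math.gcd(tp_size, old_graph_batch_size)  # 12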

    def update_torchair_graph_batch_sizes(self):
        # return graph_batch_sizes according to the max number of tokens
        # first pad according to the number of requests
        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
            # The PD-disaggregation scenario may miscalculate the batch in the MTP
            # scenario, so we force it to max_num_reqs.
            self.torchair_graph_batch_sizes = [self.max_num_reqs]
            logger.warning(
                f"is kv_consumer, torchair_graph_batch_sizes is set to [max_num_seqs] {[self.max_num_reqs]}"
            )
        elif len(self.torchair_graph_batch_sizes) == 0:
            self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
        else:
            self.torchair_graph_batch_sizes = sorted(
                self.torchair_graph_batch_sizes)
            while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
                self.torchair_graph_batch_sizes.pop()
                if len(self.torchair_graph_batch_sizes) == 0:
                    logger.warning(
                        "torchair_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
                    )
                    self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
            if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
                self.torchair_graph_batch_sizes.append(self.max_num_reqs)

        # padded max number of tokens = max_num_req * decode_token_per_req
        self.torchair_graph_batch_sizes = [
            graph_batch_size * self.decode_token_per_req
            for graph_batch_size in self.torchair_graph_batch_sizes
        ]

        # NOTE: when enable_expert_parallel is on for A3, we need to check that
        # `graph_batch_size` is divisible by `tp_size`, because the x_active_mask used
        # by the dispatch/combine op on A3 requires the input shape to be the same on
        # all EP ranks.
        if get_ascend_soc_version(
        ) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel:
            self._align_graph_size_divisible_by_tp_size()

    def _align_graph_size_divisible_by_tp_size(self):
        tp_size = self.parallel_config.tensor_parallel_size
        new_graph_batch_sizes = []
        for graph_batch_size in self.torchair_graph_batch_sizes:
            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
                graph_batch_size, tp_size)
            if cur_graph_batch_size not in new_graph_batch_sizes and \
                    cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                new_graph_batch_sizes.append(cur_graph_batch_size)
            elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
                    and self.decode_token_per_req > 1:
                logger.warning(
                    f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens "
                    f"{self.scheduler_config.max_num_batched_tokens}, will skip this batch size."
                )
        new_max_num_reqs = math.ceil(
            max(new_graph_batch_sizes) / self.decode_token_per_req)
        if self.max_num_reqs != new_max_num_reqs:
            logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
            self.max_num_reqs = new_max_num_reqs
            if not (self.decode_token_per_req > 1 and self.is_kv_consumer):
                # Do not update scheduler_config.max_num_seqs in the KV consumer + MTP
                # case, since FIA needs extra space for padding; enforce
                # self.max_num_seqs > self.scheduler_config.max_num_seqs there.
                self.scheduler_config.max_num_seqs = new_max_num_reqs

        if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
            logger.warning(
                f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}."
            )
            self.torchair_graph_batch_sizes = new_graph_batch_sizes

    def _build_drafter_prepare_inputs_torchair_param(self):
        if self.enable_shared_expert_dp:
            return super()._build_drafter_prepare_inputs_torchair_param()
        else:
            return True
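The lazy-compile helper in the runner above keeps one compiled graph per padded batch size. A condensed sketch of that memoization pattern, assuming only that torchair is importable on the target machine and that `model` is some nn.Module (names here are placeholders):

import torch
import torchair  # Ascend graph backend, as imported inside the runner above

_compiled_models: dict[int, object] = {}

def get_compiled(model: torch.nn.Module, batch_size: int):
    # Compile each captured shape exactly once and reuse it afterwards.
    if batch_size not in _compiled_models:
        config = torchair.CompilerConfig()
        backend = torchair.get_npu_backend(compiler_config=config)
        _compiled_models[batch_size] = torch.compile(
            model, dynamic=False, fullgraph=True, backend=backend)
    return _compiled_models[batch_size]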
1333  vllm_npu/torchair/torchair_sfa.py  Normal file
File diff suppressed because it is too large

63  vllm_npu/torchair/torchair_worker.py  Normal file
@@ -0,0 +1,63 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from vllm.logger import logger

import vllm_npu.envs as envs_ascend
from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.torchair.torchair_model_runner import NPUTorchairModelRunner
from vllm_npu.torchair.utils import (check_kv_cache_bytes_cache_exist,
                                     delete_torchair_cache_file,
                                     read_kv_cache_bytes_from_file)
from vllm_npu.worker.worker_v1 import NPUWorker


class NPUTorchairWorker(NPUWorker):
    """Torchair worker based on NPUWorker. Only torchair-specific code should be added in this class."""

    def determine_available_memory(self) -> int:
        """Override determine_available_memory to use cached torchair kv_cache_bytes."""

        available_kv_cache_memory = super().determine_available_memory()
        ascend_config = get_ascend_config()
        if ascend_config.enable_shared_expert_dp:
            return available_kv_cache_memory
        if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
            if check_kv_cache_bytes_cache_exist():
                old_kv_cache_bytes = read_kv_cache_bytes_from_file(
                    torch.distributed.get_rank())
                if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
                    logger.info(
                        f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
                    )
                    self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
                    return old_kv_cache_bytes
                else:
                    logger.info(
                        "Cached torchair kv_cache_bytes is too big, invalidating the old torchair_cache"
                    )
                    delete_torchair_cache_file()
            bytes_floating_tolerance = 1024 * 1024 * envs_ascend.vllm_npu_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
            available_kv_cache_memory -= bytes_floating_tolerance
            logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
            self.model_runner.new_kv_cache_bytes = available_kv_cache_memory

        return available_kv_cache_memory

    def init_device(self):
        """Override init_device to init torchair model runner"""
        device = self._init_device()
        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUTorchairModelRunner(self.vllm_config, device)
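The memory-sizing logic above reduces to a small pure decision: reuse the cached byte count when it still fits, otherwise drop the cache and keep a safety margin below the fresh measurement. A standalone sketch (the 64 MB tolerance is a made-up placeholder; the real value comes from an environment variable):

from typing import Optional

def pick_kv_cache_bytes(measured: int, cached: Optional[int],
                        tolerance_mb: int = 64) -> int:
    if cached is not None and 0 < cached <= measured:
        return cached                                   # cached value still fits: reuse it
    return measured - tolerance_mb * 1024 * 1024        # otherwise keep a safety margin

print(pick_kv_cache_bytes(8 * 1024**3, cached=7 * 1024**3))   # reuses the cached value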
275  vllm_npu/torchair/utils.py  Normal file
@@ -0,0 +1,275 @@
import fcntl
import os
import shutil
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass

import torch
import torch_npu
from torchair.scope import super_kernel as _super_kernel

try:
    # Recent release of torchair has moved these ops to `.scope`.
    from torchair.scope import npu_stream_switch as _npu_stream_switch
    from torchair.scope import npu_wait_tensor as _npu_wait_tensor
except ImportError:
    from torchair.ops import NpuStreamSwitch as _npu_stream_switch
    from torchair.ops import npu_wait_tensor as _npu_wait_tensor

import vllm_npu.envs as envs_ascend
from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
TORCHAIR_CACHE_DIR = os.path.join(
    os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME)


@dataclass
class TorchairCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    num_reqs: int
    """Number of requests"""

    num_actual_tokens: int
    """Total number of tokens in batch"""

    decode_token_per_req: int

    actual_seq_lengths_q: list[int]

    attn_mask: torch.Tensor = None

    spec_attn_mask: torch.Tensor = None

    graph_pad_size: int = -1


@contextmanager
def _file_lock(file_descriptor, lock_type):
    fcntl.flock(file_descriptor, lock_type)
    try:
        yield
    finally:
        fcntl.flock(file_descriptor, fcntl.LOCK_UN)
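Example usage of the _file_lock helper defined above (illustrative only; /tmp/kv_cache_bytes is an arbitrary path): take an exclusive lock while writing and a shared lock while reading, mirroring the read/write helpers later in this file:

import fcntl

with open("/tmp/kv_cache_bytes", "w", encoding="utf-8") as f:
    with _file_lock(f, fcntl.LOCK_EX):
        f.write("1073741824")

with open("/tmp/kv_cache_bytes", "r", encoding="utf-8") as f:
    with _file_lock(f, fcntl.LOCK_SH):
        cached_bytes = int(f.readline())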


def _get_torchair_current_work_dir(file_name=None):
    if file_name is None:
        return TORCHAIR_CACHE_DIR
    return os.path.join(TORCHAIR_CACHE_DIR, file_name)


def check_torchair_cache_exist():
    res = False
    torch_air_abs_path = _get_torchair_current_work_dir()
    if os.path.exists(torch_air_abs_path):
        file_list = os.listdir(torch_air_abs_path)
        if len(file_list) != 0:
            res = True
    return res


def check_kv_cache_bytes_cache_exist():
    res = False
    kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
        KV_CACHE_BYTES_CACHE_PATH_NAME)
    if os.path.exists(kv_cache_bytes_cache_abs_path):
        file_list = os.listdir(kv_cache_bytes_cache_abs_path)
        if len(file_list) != 0:
            res = True
    return res


def read_kv_cache_bytes_from_file(rank) -> int:
    kv_cache_bytes = -1
    kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
        KV_CACHE_BYTES_CACHE_PATH_NAME)
    kv_cache_bytes_file = os.path.join(
        kv_cache_bytes_cache_abs_path,
        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
    with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
        with _file_lock(f, fcntl.LOCK_SH):
            kv_cache_bytes = int(f.readline())
    return kv_cache_bytes


def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
    kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
        KV_CACHE_BYTES_CACHE_PATH_NAME)
    os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
    kv_cache_bytes_file = os.path.join(
        kv_cache_bytes_cache_abs_path,
        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
    with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
        with _file_lock(f, fcntl.LOCK_EX):
            f.write(f"{kv_cache_bytes}")


def delete_torchair_cache_file():
    torch_air_abs_path = _get_torchair_current_work_dir()
    try:
        shutil.rmtree(torch_air_abs_path)
    except FileNotFoundError:
        pass


def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    *,
                    enabled: bool = True):
    return _npu_wait_tensor(self, dependency) if enabled else self
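A hypothetical use of the two wrappers above (runnable only with torchair on an NPU; the tag name and tensors are invented): issue one matmul on a side stream while the main stream continues, then make the consumer wait for the side-stream result:

import torch

def overlapped_matmuls(x, w_side, w_main, enabled=True):
    with npu_stream_switch("side_stream", 0, enabled=enabled):
        side_out = torch.matmul(x, w_side)      # issued on the secondary stream
    main_out = torch.matmul(x, w_main)          # main stream keeps going
    # Make the main-stream consumer wait until the side-stream result is ready.
    main_out = npu_wait_tensor(main_out, side_out, enabled=enabled)
    return side_out + main_out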


def converting_weight_acl_format(model, format):
    # currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
    # in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
    # is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
    # conversion when using torchair graph mode on 300I Duo platform.
    # TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
    # accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    for module in model.modules():
        if isinstance(module, FusedMoE):
            if torch_npu.get_npu_format(module.w13_weight.data) == format:
                return
            if format == ACL_FORMAT_FRACTAL_NZ \
                    and not is_enable_nz(module.w13_weight.data.dtype):
                return
            module.w13_weight.data = torch_npu.npu_format_cast(
                module.w13_weight.data, format)
            module.w2_weight.data = torch_npu.npu_format_cast(
                module.w2_weight.data, format)


def register_torchair_model():
    from vllm import ModelRegistry

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_npu.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
    )

    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",
        "vllm_npu.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepseekV3ForCausalLM",
        "vllm_npu.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM",
        "vllm_npu.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
    )

    ModelRegistry.register_model(
        "Qwen2ForCausalLM",
        "vllm_npu.torchair.models.qwen2:CustomQwen2ForCausalLM")

    ModelRegistry.register_model(
        "Qwen3MoeForCausalLM",
        "vllm_npu.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")

    ModelRegistry.register_model(
        "PanguProMoEForCausalLM",
        "vllm_npu.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
    )


def torchair_quant_method_register():
    from vllm_npu.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP
    from vllm_npu.torchair.quantization.torchair_w4a8_dynamic import (
        TorchairAscendW4A8DynamicFusedMoEMethod,
        TorchairAscendW4A8DynamicLinearMethod)
    from vllm_npu.torchair.quantization.torchair_w8a8_dynamic import (
        TorchairAscendW8A8DynamicFusedMoEMethod,
        TorchairAscendW8A8DynamicLinearMethod)

    ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
        "linear"] = TorchairAscendW8A8DynamicLinearMethod
    ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
        "moe"] = TorchairAscendW8A8DynamicFusedMoEMethod
    ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
        "linear"] = TorchairAscendW4A8DynamicLinearMethod
    ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
        "moe"] = TorchairAscendW4A8DynamicFusedMoEMethod


def torchair_ops_patch():
    from vllm_npu.ops.activation import AscendSiluAndMul
    from vllm_npu.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
    from vllm_npu.ops.rotary_embedding import (
        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
    from vllm_npu.ops.vocab_parallel_embedding import \
        AscendVocabParallelEmbedding
    from vllm_npu.torchair.ops import (torchair_activation,
                                       torchair_layernorm)
    from vllm_npu.torchair.ops.torchair_rotary_embedding import (
        deepseek_rope_init_func, native_rope_deepseek_forward,
        qwen_rope_init_func, rope_forward)
    from vllm_npu.torchair.ops.torchair_vocab_parallel_embedding import \
        vocab_embedding_forward

    AscendRotaryEmbedding.__init__ = qwen_rope_init_func  # type: ignore[method-assign]
    AscendRotaryEmbedding.forward_oot = rope_forward  # type: ignore[method-assign]

    AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func  # type: ignore[method-assign]
    AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward  # type: ignore[method-assign]

    AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_  # type: ignore[method-assign]
    AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot  # type: ignore[method-assign]

    AscendQuantRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_  # type: ignore[method-assign]
    AscendQuantRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot  # type: ignore[method-assign]

    AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot  # type: ignore[method-assign]
    AscendVocabParallelEmbedding.forward = vocab_embedding_forward  # type: ignore[method-assign]


def super_kernel(prefix: str, option: str, enabled: bool = True):
    return _super_kernel(prefix, option) if enabled else nullcontext()


# TODO(ttanzhiqiang): rm_router_logits
# Triggered when dp > 1.
# In theory this optimization only applies to the AllGather and AllGatherEP paths:
# in the dp scenario the previous flow was gate + two communications, and it becomes
# one communication + gate, which saves some communication time. Every MoE
# AllGather/AllGatherEP path could in principle follow this logic, but the dp paths
# of the other MoE models (e.g. qwen3-235b) have not been adjusted yet, so a switch
# is used to control it and avoid breaking them.
def get_rm_router_logits_state(ep_size: int, dp_size: int,
                               is_deepseek_v3_r1: bool):
    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
    # only supports deepseek v3/r1
    if dp_size > 1:
        if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
                and is_deepseek_v3_r1):
            return True
        elif ep_size == 1 and is_deepseek_v3_r1:
            return True
    return False


# TODO(ttanzhiqiang): all_reduce merge
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in the
# MLP; it waits until shared_experts + router_experts are completed before doing
# all_reduce.
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP
# and NaiveMulticast scenarios of the deepseek model.
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
    # only supports deepseek v3/r1
    if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
            and is_deepseek_v3_r1):
        return True
    elif ep_size == 1 and is_deepseek_v3_r1:
        return True
    return False
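A quick illustration of how the two switches above combine (values invented): with dp_size=2, ep_size=1 and a DeepSeek V3/R1 model, both the router-logits rework and the merged all-reduce are enabled; with dp_size=1 the router-logits rework stays off:

print(get_rm_router_logits_state(ep_size=1, dp_size=2, is_deepseek_v3_r1=True))  # True
print(get_all_reduce_merge_state(ep_size=1, is_deepseek_v3_r1=True))             # True
print(get_rm_router_logits_state(ep_size=1, dp_size=1, is_deepseek_v3_r1=True))  # False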