vllm-npu-plugin/vllm_npu/torchair/models/torchair_pangu_moe.py
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from torch.nn import Parameter
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group, get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.v1.sample.sampler import Sampler
from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
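# Module-level handle to the current MoE layer's router_scale parameter. The
# custom routing function passed to FusedMoE has a fixed signature, so the
# scale is handed over through this global (set in
# PanguProMoESparseMoeBlock.forward before the experts are invoked).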
_ROUTER_SCALE = None
def use_h2p():
    # Only use H2P when dp_size > 1.
    return get_dp_group().world_size > 1
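# H2P detaches the TP/EP collectives from the linear layers themselves: the
# custom linear classes below perform no communication in their forward pass,
# and PanguProMoEDecoderLayer issues the required all-gather/reduce-scatter
# collectives explicitly instead.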
# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
# It is used to customize the parallelism of certain linear layers (e.g., shared experts with all-rank TP).
class CustomMergedColumnParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
# Divide the weight matrix along the last dimension.
output_size = sum(output_sizes)
self.output_sizes = output_sizes
self.tp_size = get_tp_group().world_size
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
loaded_shard_id: int):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tp_group().rank_in_group
tp_size = get_tp_group().world_size
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
is_sharded_weight = getattr(param, "is_sharded_weight", False)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def forward(
self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear.
# It is used to customize the parallelism of certain linear layers (e.g., shared experts with all-rank TP)
# and to detach communication so that customized communication algorithms (e.g., H2P) can be applied.
class CustomRowParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
group=None,
):
# Divide the weight matrix along the first dimension.
self.group = group if group is not None else get_tp_group()
self.tp_rank = self.group.rank_in_group
self.tp_size = self.group.world_size
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = self.group.rank_in_group
input_dim = getattr(param, "input_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
param_data = param.data
if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def forward(
self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
input_parallel = input_
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output = self.quant_method.apply(self, input_parallel, bias=bias_)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class PanguProMoEMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
if not use_h2p():
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
else:
self.gate_up_proj = CustomMergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = CustomRowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
def topk_wrapper(num_voted_experts):
def pangu_group8_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool = False,
num_expert_group: int = 0,
topk_group: int = 0,
global_num_experts: int = 0,
):
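        # Pangu grouped top-k: the routed experts are divided into `topk`
        # contiguous groups and exactly one expert is selected from each group.
        # Illustrative example (assuming 64 routed experts, topk = 8, and
        # expert parallelism disabled): experts_per_group = 64 // 8 = 8, the
        # scores are viewed as (num_tokens, 8, 8), the argmax is taken within
        # each group of 8, and the group offsets [0, 8, 16, ..., 56] are added
        # back to recover global expert ids.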
scores = F.softmax(gating_output, dim=1)
num_tokens = scores.shape[0]
        router_scale = _ROUTER_SCALE.squeeze()  # type: ignore
        # TODO: support disabling expert parallelism
ep_size = get_ep_group().world_size
local_num_experts = global_num_experts // ep_size
local_num_group = topk // ep_size
experts_per_group = global_num_experts // topk
local_group_start = get_ep_group().rank_in_group * local_num_experts
local_group_end = (get_ep_group().rank_in_group +
1) * local_num_experts
scores = scores[..., local_group_start:local_group_end]
router_weights = router_scale[local_group_start:local_group_end]
if num_voted_experts == 8:
# use original topk
topk_weights, topk_ids = torch.max(scores.view(
scores.shape[0], local_num_group, -1),
dim=-1)
bias = torch.arange(0,
local_num_experts,
experts_per_group,
device=scores.device,
dtype=torch.int32).unsqueeze(0)
topk_ids = topk_ids.to(torch.int32) + bias
else:
group_expert_indices = torch.arange(experts_per_group,
dtype=torch.int32,
device=scores.device).view(
1, 1, -1)
group_expert_offset = (torch.arange(
local_num_group, dtype=torch.int32, device=scores.device) *
experts_per_group).unsqueeze(0)
expert_index_range = torch.arange(experts_per_group,
dtype=torch.int32,
device=scores.device)
scores_grouped = scores.view(num_tokens, local_num_group,
experts_per_group)
best_expert_idx = torch.argmax(scores_grouped,
dim=2) # (num_tokens, num_groups)
vote_mask = (best_expert_idx.unsqueeze(-1).to(
torch.int32) == group_expert_indices)
expert_vote_freq = vote_mask.sum(dim=0)
sorted_indices = torch.argsort(expert_vote_freq,
dim=1,
descending=True).to(torch.int32)
topk_experts = sorted_indices[:, :num_voted_experts]
keep_mask = ((
topk_experts.unsqueeze(-1) == expert_index_range).any(
dim=1)).unsqueeze(0)
masked_scores = torch.where(keep_mask, scores_grouped, 0)
topk_weights, best_pos_in_group = masked_scores.max(dim=2)
best_pos_in_group = best_pos_in_group.to(torch.int32)
topk_ids = (best_pos_in_group + group_expert_offset).to(
torch.int32)
flatten_topk_ids = topk_ids.view(-1)
router_weights = router_weights.index_select(0, flatten_topk_ids).view(
topk_ids.shape)
topk_weights *= router_weights
return topk_weights, topk_ids
return pangu_group8_topk
class PanguProMoESparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_experts = config.num_experts
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.num_experts_per_tok = config.num_experts_per_tok
self.router_scale = torch.nn.Parameter(
torch.ones((1, self.num_experts)))
        # On the 300I Duo platform, setting num_voted_experts to 5 achieves good
        # performance without sacrificing too much accuracy. On other platforms,
        # it is set to 8 to use the original Pangu grouped topk.
num_voted_experts = 5 if is_310p() else 8
self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
quant_config=quant_config,
custom_routing_function=topk_wrapper(num_voted_experts),
prefix=f"{prefix}.experts",
)
self.use_ep = self.experts.use_ep
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
if config.shared_expert_intermediate_size > 0:
self.shared_expert = PanguProMoEMLP(
hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_expert",
)
else:
self.shared_expert = None # type: ignore
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        # NOTE: hidden_states is expected to be 2D with shape (num_tokens, hidden_dim).
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
global _ROUTER_SCALE
_ROUTER_SCALE = self.router_scale
# TODO(angazenn): Does not support MC2 currently
get_forward_context().moe_comm_method_name = "allgathercommimpl"
if not use_h2p():
final_hidden_states = self.experts.forward_impl(
hidden_states=hidden_states, router_logits=router_logits)
else:
            # TODO: when using H2P, we have to skip the communication in vLLM's
            # native FusedMoE. We need a better FusedMoE design here (maybe
            # based on AscendFusedMoE) to support these different
            # communication schemes.
final_hidden_states = self.experts.quant_method.apply(
layer=self.experts,
x=hidden_states,
router_logits=router_logits,
top_k=self.experts.top_k,
renormalize=False,
use_grouped_topk=False,
global_num_experts=self.experts.global_num_experts,
expert_map=self.experts.expert_map,
custom_routing_function=self.experts.custom_routing_function,
apply_router_weight_on_input=self.experts.
apply_router_weight_on_input)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
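        # Under H2P, the TP all-reduce below is skipped; the decoder layer
        # instead performs a world-group reduce-scatter after the MoE block.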
if not use_h2p():
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class PanguProMoEAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
            # The number of KV heads is greater than or equal to the TP size,
            # so we partition the KV heads across the tensor parallel ranks.
assert self.total_num_kv_heads % tp_size == 0
else:
            # The number of KV heads is less than the TP size, so we replicate
            # the KV heads across the tensor parallel ranks.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
if use_h2p():
self.o_proj = CustomRowParallelLinear(self.total_num_heads *
self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
group=get_tp_group())
else:
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
if self.torchair_graph_enabled:
forward_kwargs = {'trace_flag': False}
output_shape = q.shape
attn_output = torch.empty(output_shape,
dtype=q.dtype,
device=q.device)
forward_kwargs['output'] = attn_output
attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
attn_metadata,
**forward_kwargs)
else:
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class PanguProMoEDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = PanguProMoEAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
        # Use a dense MLP instead of the MoE block for layers listed in
        # `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (layer_idx not in mlp_only_layers) and (config.num_experts > 0):
self.mlp = PanguProMoESparseMoeBlock(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = PanguProMoEMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
h2p_unpad_idx: Optional[torch.Tensor] = None,
h2p_pad_idx: Optional[torch.Tensor] = None,
is_start_layer: Optional[bool] = False,
) -> torch.Tensor:
need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \
and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]
tp_size = get_tp_group().world_size
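        # H2P collective schedule for this layer: the attention input is
        # all-gathered over the TP group (on the first layer the input is not
        # yet scattered, so only the residual is padded and split), the
        # attention output is reduce-scattered over the TP group, and the MoE
        # input/output are all-gathered/reduce-scattered over the world group.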
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if use_h2p():
if is_start_layer:
if need_h2p_pad:
residual = residual.index_select(dim=0, index=h2p_pad_idx)
residual = torch.tensor_split(
residual, tp_size)[get_tp_group().rank_in_group]
else:
if tp_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
if need_h2p_pad:
hidden_states = hidden_states.index_select(
dim=0, index=h2p_unpad_idx)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if use_h2p():
if need_h2p_pad:
hidden_states = hidden_states.index_select(dim=0,
index=h2p_pad_idx)
if tp_size > 1:
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",
scatter_dim=0,
group=get_tp_group().device_group)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if use_h2p():
all_rank_group = get_world_group().device_group
output_size = (hidden_states.shape[0] *
get_world_group().world_size,
hidden_states.shape[1])
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=hidden_states.dtype,
device=hidden_states.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
hidden_states,
group=all_rank_group)
hidden_states = output_tensor
hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata)
if use_h2p():
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",
scatter_dim=0,
group=get_world_group().device_group)
return hidden_states, residual
@support_torch_compile
class PanguProMoEModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: PanguProMoEDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
if use_h2p():
            # Calculate the necessary padding/unpadding indices before the
            # model forward. attn_metadata is passed in directly when using
            # torchair; if it is not passed, we try to get it from the
            # forward context.
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
max_tokens_across_dp = get_forward_context().max_tokens_across_dp
tp_size = get_tp_group().world_size
            # Reduce-scatter splits the input tensor into equal chunks and then
            # scatters them across all ranks, so the input must be padded
            # beforehand if its length is not divisible by the group size. For
            # H2P, we pad it so that it is divisible by tp_size.
h2p_padded_len = (
tp_size - (max_tokens_across_dp % tp_size)
) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
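            # Illustrative example (assumed values): with tp_size = 4,
            # max_tokens_across_dp = 10 and 7 local tokens, the padded length
            # is 12 tokens, so h2p_padded_len = 12 - 7 = 5 and the padded
            # hidden_states can be evenly reduce-scattered across the TP group.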
h2p_unpad_idx = torch.arange(hidden_states.shape[0],
device=hidden_states.device,
dtype=torch.int32)
h2p_pad_idx = torch.cat([
h2p_unpad_idx,
torch.zeros(h2p_padded_len,
dtype=torch.int32,
device=hidden_states.device)
])
else:
h2p_unpad_idx = None
h2p_pad_idx = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata, h2p_unpad_idx, h2p_pad_idx,
i == self.start_layer)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if use_h2p():
if get_tp_group().world_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
hidden_states = hidden_states.index_select(dim=0,
index=h2p_unpad_idx)
return hidden_states
class PanguProMoEForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = PanguProMoEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata=None, # type: ignore
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata, # type: ignore
):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) # from model
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
            # Skip weights of layers beyond this rank's end_layer so the model
            # can be loaded with fewer layers.
if 'layers' in name:
layer_idx = int(name.split('layers.')[-1].split('.')[0])
if layer_idx >= self.model.end_layer:
continue
if "rotary_emb.inv_freq" in name:
continue
if "module" in name:
continue
if name.endswith('kv_cache_offset'):
continue
if name.endswith("k_proj.kv_cache_scale"):
remapped_kv_scale_name = name.replace(
"k_proj.kv_cache_scale", "attn.key_antiquant_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
loaded_weight = torch.tensor_split(loaded_weight,
tp_size,
dim=0)[tp_rank]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if name.endswith("v_proj.kv_cache_scale"):
remapped_kv_scale_name = name.replace(
"v_proj.kv_cache_scale", "attn.value_antiquant_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
loaded_weight = torch.tensor_split(loaded_weight,
tp_size,
dim=0)[tp_rank]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if is_310p() and "head" in name:
                # On the 300I Duo platform, matmul strongly prefers ACL_FORMAT_FRACTAL_NZ
                # over ACL_FORMAT_FRACTAL_ND. Since lm_head is also implemented as a
                # linear layer, we manually cast the format here.
param.data = torch_npu.npu_format_cast(param.data,
ACL_FORMAT_FRACTAL_NZ)
return loaded_params