#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from torch.nn import Parameter
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                             get_tp_group, get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
    extract_layer_index, is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.v1.sample.sampler import Sampler

from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_310p

_ROUTER_SCALE = None


def use_h2p():
    # Only use H2P when dp_size > 1.
    if get_dp_group().world_size > 1:
        return True
    return False


# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
# It is used to customize the parallelism of certain linear layers
# (e.g., shared experts with all-rank TP).
class CustomMergedColumnParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension.
        output_size = sum(output_sizes)
        self.output_sizes = output_sizes
        self.tp_size = get_tp_group().world_size
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.gather_output = gather_output

        if output_sizes is None:
            output_sizes = [output_size]

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
                      loaded_shard_id: int):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        assert loaded_shard_id < len(self.output_sizes)

        tp_rank = get_tp_group().rank_in_group
        tp_size = get_tp_group().world_size
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
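

# The sketch below is illustrative only and is not used by the model code.
# Under the same assumptions as `weight_loader` above, it shows which slice of
# a merged [gate; up] checkpoint weight each TP rank copies: the offset of
# shard `loaded_shard_id` inside the local merged output dimension, and the
# per-rank slice taken from that checkpoint shard. `_merged_shard_slice` is a
# hypothetical helper name introduced here for illustration.
def _merged_shard_slice(output_sizes: List[int], loaded_shard_id: int,
                        tp_rank: int, tp_size: int) -> Tuple[int, int, int]:
    """Return (dst_offset, shard_size, src_start) for one merged shard.

    Example: output_sizes=[1024, 1024] (gate, up), tp_size=4, tp_rank=1,
    loaded_shard_id=1 -> (256, 256, 256), i.e. rank 1 copies rows 256:512 of
    the up_proj checkpoint weight into rows 256:512 of its local merged
    gate_up weight.
    """
    dst_offset = sum(output_sizes[:loaded_shard_id]) // tp_size
    shard_size = output_sizes[loaded_shard_id] // tp_size
    src_start = tp_rank * shard_size
    return dst_offset, shard_size, src_start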


# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear.
# It is used to customize the parallelism of certain linear layers
# (e.g., shared experts with all-rank TP) and to detach communication so that
# customized communication algorithms (e.g., H2P) can be applied.
class CustomRowParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        group=None,
    ):
        # Divide the weight matrix along the first dimension.
        self.group = group if group is not None else get_tp_group()
        self.tp_rank = self.group.rank_in_group
        self.tp_size = self.group.world_size
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = self.group.rank_in_group
        input_dim = getattr(param, "input_dim", None)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        input_parallel = input_

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output = self.quant_method.apply(self, input_parallel, bias=bias_)

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias


class PanguProMoEMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if not use_h2p():
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.gate_up_proj",
            )
            self.down_proj = RowParallelLinear(
                intermediate_size,
                hidden_size,
                bias=False,
                quant_config=quant_config,
                reduce_results=reduce_results,
                prefix=f"{prefix}.down_proj",
            )
        else:
            self.gate_up_proj = CustomMergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.gate_up_proj",
            )
            self.down_proj = CustomRowParallelLinear(
                intermediate_size,
                hidden_size,
                bias=False,
                quant_config=quant_config,
                reduce_results=reduce_results,
                prefix=f"{prefix}.down_proj",
            )

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
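

# Minimal reference sketch (not used by the model code): `gate_up_proj` above
# produces a tensor whose last dimension holds [gate; up] concatenated, and
# `SiluAndMul` computes silu(gate) * up. The function below reproduces that
# contract with plain torch ops so the data flow of PanguProMoEMLP.forward is
# easy to follow; `_silu_and_mul_reference` is a hypothetical name introduced
# here for illustration.
def _silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: [..., 2 * intermediate_size] -> [..., intermediate_size]
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up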


def topk_wrapper(num_voted_experts):

    def pangu_group8_topk(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool = False,
        num_expert_group: int = 0,
        topk_group: int = 0,
        global_num_experts: int = 0,
    ):
        scores = F.softmax(gating_output, dim=1)
        num_tokens = scores.shape[0]
        router_scale = _ROUTER_SCALE.squeeze()  # type: ignore
        # TODO: support disable expert parallel
        ep_size = get_ep_group().world_size
        local_num_experts = global_num_experts // ep_size
        local_num_group = topk // ep_size
        experts_per_group = global_num_experts // topk
        local_group_start = get_ep_group().rank_in_group * local_num_experts
        local_group_end = (get_ep_group().rank_in_group +
                           1) * local_num_experts
        scores = scores[..., local_group_start:local_group_end]

        router_weights = router_scale[local_group_start:local_group_end]

        if num_voted_experts == 8:
            # use original topk
            topk_weights, topk_ids = torch.max(scores.view(
                scores.shape[0], local_num_group, -1),
                                               dim=-1)
            bias = torch.arange(0,
                                local_num_experts,
                                experts_per_group,
                                device=scores.device,
                                dtype=torch.int32).unsqueeze(0)
            topk_ids = topk_ids.to(torch.int32) + bias

        else:
            group_expert_indices = torch.arange(experts_per_group,
                                                dtype=torch.int32,
                                                device=scores.device).view(
                                                    1, 1, -1)
            group_expert_offset = (torch.arange(
                local_num_group, dtype=torch.int32, device=scores.device) *
                                   experts_per_group).unsqueeze(0)
            expert_index_range = torch.arange(experts_per_group,
                                              dtype=torch.int32,
                                              device=scores.device)

            scores_grouped = scores.view(num_tokens, local_num_group,
                                         experts_per_group)
            best_expert_idx = torch.argmax(scores_grouped,
                                           dim=2)  # (num_tokens, num_groups)
            vote_mask = (best_expert_idx.unsqueeze(-1).to(
                torch.int32) == group_expert_indices)

            expert_vote_freq = vote_mask.sum(dim=0)

            sorted_indices = torch.argsort(expert_vote_freq,
                                           dim=1,
                                           descending=True).to(torch.int32)
            topk_experts = sorted_indices[:, :num_voted_experts]
            keep_mask = ((
                topk_experts.unsqueeze(-1) == expert_index_range).any(
                    dim=1)).unsqueeze(0)

            masked_scores = torch.where(keep_mask, scores_grouped, 0)

            topk_weights, best_pos_in_group = masked_scores.max(dim=2)
            best_pos_in_group = best_pos_in_group.to(torch.int32)
            topk_ids = (best_pos_in_group + group_expert_offset).to(
                torch.int32)

        flatten_topk_ids = topk_ids.view(-1)
        router_weights = router_weights.index_select(0, flatten_topk_ids).view(
            topk_ids.shape)
        topk_weights *= router_weights

        return topk_weights, topk_ids

    return pangu_group8_topk
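

# Illustrative sketch of the grouped routing above (not used by the model and
# written for a single rank, i.e. ep_size == 1). PanguProMoE splits the
# experts into `topk` equal groups and picks exactly one expert per group:
# the argmax inside each group, offset by the group's starting expert id.
# The per-expert router_scale multiplication is omitted here for brevity.
# `_demo_grouped_topk` and its toy sizes are hypothetical, for illustration.
def _demo_grouped_topk(num_experts: int = 16, topk: int = 4,
                       num_tokens: int = 2):
    experts_per_group = num_experts // topk
    gating_output = torch.randn(num_tokens, num_experts)
    scores = F.softmax(gating_output, dim=1)
    # One argmax per group of `experts_per_group` consecutive experts.
    topk_weights, topk_ids = torch.max(
        scores.view(num_tokens, topk, experts_per_group), dim=-1)
    # Convert per-group positions to global expert ids.
    group_offsets = torch.arange(0, num_experts, experts_per_group,
                                 dtype=torch.int32).unsqueeze(0)
    topk_ids = topk_ids.to(torch.int32) + group_offsets
    return topk_weights, topk_ids  # both of shape (num_tokens, topk)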


class PanguProMoESparseMoeBlock(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_experts = config.num_experts

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        self.num_experts_per_tok = config.num_experts_per_tok
        self.router_scale = torch.nn.Parameter(
            torch.ones((1, self.num_experts)))

        # On the 300I Duo platform, setting num_voted_experts to 5 achieves
        # good performance without sacrificing too much accuracy. On other
        # platforms it is set to 8, which uses the original PanguProMoE
        # grouped top-k.
        num_voted_experts = 5 if is_310p() else 8

        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            quant_config=quant_config,
            custom_routing_function=topk_wrapper(num_voted_experts),
            prefix=f"{prefix}.experts",
        )
        self.use_ep = self.experts.use_ep

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = PanguProMoEMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None  # type: ignore

    def forward(
            self,
            hidden_states: torch.Tensor,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        # NOTE: hidden_states is expected to be 2D here: (num_tokens, hidden_dim).
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        global _ROUTER_SCALE
        _ROUTER_SCALE = self.router_scale

        # TODO(angazenn): Does not support MC2 currently
        get_forward_context().moe_comm_method_name = "allgathercommimpl"

        if not use_h2p():
            final_hidden_states = self.experts.forward_impl(
                hidden_states=hidden_states, router_logits=router_logits)
        else:
            # TODO: when using H2P, we have to skip the communication in
            # vLLM's native FusedMoE. A better FusedMoE design (maybe using
            # AscendFusedMoE) is needed to enable these different
            # communication schemas.
            final_hidden_states = self.experts.quant_method.apply(
                layer=self.experts,
                x=hidden_states,
                router_logits=router_logits,
                top_k=self.experts.top_k,
                renormalize=False,
                use_grouped_topk=False,
                global_num_experts=self.experts.global_num_experts,
                expert_map=self.experts.expert_map,
                custom_routing_function=self.experts.custom_routing_function,
                apply_router_weight_on_input=self.experts.
                apply_router_weight_on_input)

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if not use_h2p():
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class PanguProMoEAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
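        # Worked example (illustrative numbers, not from any config): with
        # total_num_kv_heads = 8 and tp_size = 4, each rank keeps
        # 8 // 4 = 2 KV heads; with tp_size = 16, 8 // 16 = 0, so
        # max(1, 0) = 1 means every rank holds a replicated KV head.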
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        if use_h2p():
            self.o_proj = CustomRowParallelLinear(self.total_num_heads *
                                                  self.head_dim,
                                                  hidden_size,
                                                  bias=True,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.o_proj",
                                                  group=get_tp_group())
        else:
            self.o_proj = RowParallelLinear(
                self.total_num_heads * self.head_dim,
                hidden_size,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj",
            )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.torchair_graph_enabled:
            forward_kwargs = {'trace_flag': False}
            output_shape = q.shape
            attn_output = torch.empty(output_shape,
                                      dtype=q.dtype,
                                      device=q.device)
            forward_kwargs['output'] = attn_output
            attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
                                                 attn_metadata,
                                                 **forward_kwargs)
        else:
            attn_output = self.attn(q, k, v)

        output, _ = self.o_proj(attn_output)
        return output


class PanguProMoEDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        self.self_attn = PanguProMoEAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # Use a dense MLP for layers listed in `mlp_only_layers` in the
        # config; all other layers use the sparse MoE block.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                           config.mlp_only_layers)
        if (layer_idx not in mlp_only_layers) and (config.num_experts > 0):
            self.mlp = PanguProMoESparseMoeBlock(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = PanguProMoEMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        h2p_unpad_idx: Optional[torch.Tensor] = None,
        h2p_pad_idx: Optional[torch.Tensor] = None,
        is_start_layer: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        need_h2p_pad = (h2p_unpad_idx is not None and h2p_pad_idx is not None
                        and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0])
        tp_size = get_tp_group().world_size

        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        if use_h2p():
            if is_start_layer:
                if need_h2p_pad:
                    residual = residual.index_select(dim=0, index=h2p_pad_idx)
                residual = torch.tensor_split(
                    residual, tp_size)[get_tp_group().rank_in_group]
            else:
                if tp_size > 1:
                    hidden_states = get_tp_group().all_gather(hidden_states, 0)
                if need_h2p_pad:
                    hidden_states = hidden_states.index_select(
                        dim=0, index=h2p_unpad_idx)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if use_h2p():
            if need_h2p_pad:
                hidden_states = hidden_states.index_select(dim=0,
                                                           index=h2p_pad_idx)
            if tp_size > 1:
                hidden_states = dist._functional_collectives.reduce_scatter_tensor(
                    hidden_states,
                    "sum",
                    scatter_dim=0,
                    group=get_tp_group().device_group)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        if use_h2p():
            all_rank_group = get_world_group().device_group
            output_size = (hidden_states.shape[0] *
                           get_world_group().world_size,
                           hidden_states.shape[1])
            # Allocate output tensor.
            output_tensor = torch.empty(output_size,
                                        dtype=hidden_states.dtype,
                                        device=hidden_states.device)
            # All-gather.
            dist.all_gather_into_tensor(output_tensor,
                                        hidden_states,
                                        group=all_rank_group)
            hidden_states = output_tensor

        hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata)

        if use_h2p():
            hidden_states = dist._functional_collectives.reduce_scatter_tensor(
                hidden_states,
                "sum",
                scatter_dim=0,
                group=get_world_group().device_group)

        return hidden_states, residual
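

# H2P, as implemented in PanguProMoEDecoderLayer.forward above, keeps the
# attention phase on TP-sharded token slices and runs the MoE phase over all
# ranks. Illustrative walkthrough under an assumed example configuration of
# tp_size = 4 and dp_size = 2 (8 ranks in the world group), with each rank
# holding S = (padded local batch) / tp_size token rows between phases:
#   1. before attention: all-gather over the TP group (4 * S rows), then drop
#      the padding rows added earlier;
#   2. after attention:  re-add padding and reduce-scatter over the TP group
#      -> back to S rows per rank;
#   3. before the MoE:   all-gather over the world group -> 8 * S rows, i.e.
#      tokens from every DP rank;
#   4. after the MoE:    reduce-scatter over the world group -> S rows.
# The residual stream stays at S rows per rank; only the start layer splits it.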


@support_torch_compile
class PanguProMoEModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens")

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PanguProMoEDecoderLayer(config=config,
                                                   cache_config=cache_config,
                                                   quant_config=quant_config,
                                                   prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        if use_h2p():
            # Calculate the necessary padding/unpadding indices before the
            # model forward.

            # attn_metadata is passed in directly when using torchair.
            # If it is not passed, try to get it from the forward context.
            if attn_metadata is None:
                attn_metadata = get_forward_context().attn_metadata

            max_tokens_across_dp = get_forward_context().max_tokens_across_dp

            tp_size = get_tp_group().world_size
            # Reduce-scatter splits the input tensor into equal chunks and
            # scatters them across all ranks, so the token dimension must be
            # divisible by the group size. For H2P we therefore pad it to a
            # multiple of tp_size.
            h2p_padded_len = (
                tp_size - (max_tokens_across_dp % tp_size)
            ) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
            h2p_unpad_idx = torch.arange(hidden_states.shape[0],
                                         device=hidden_states.device,
                                         dtype=torch.int32)
            h2p_pad_idx = torch.cat([
                h2p_unpad_idx,
                torch.zeros(h2p_padded_len,
                            dtype=torch.int32,
                            device=hidden_states.device)
            ])
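            # Worked example (illustrative numbers): with tp_size = 4,
            # max_tokens_across_dp = 10 and 7 local tokens, the padded length
            # is (4 - 10 % 4) % 4 + 10 - 7 = 2 + 3 = 5, so h2p_pad_idx has 12
            # entries and the layers later pad 7 rows up to 12, which divides
            # evenly by 4.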
        else:
            h2p_unpad_idx = None
            h2p_pad_idx = None

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, residual,
                kv_caches[i - self.start_layer]
                if kv_caches is not None else None, attn_metadata,
                h2p_unpad_idx, h2p_pad_idx, i == self.start_layer)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        if use_h2p():
            if get_tp_group().world_size > 1:
                hidden_states = get_tp_group().all_gather(hidden_states, 0)
            if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
                hidden_states = hidden_states.index_select(dim=0,
                                                           index=h2p_unpad_idx)
        return hidden_states


class PanguProMoEForCausalLM(nn.Module, SupportsPP):

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = PanguProMoEModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.lm_head",
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata=None,  # type: ignore
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata,  # type: ignore
    ):
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        tp_size = get_tp_group().world_size
        tp_rank = get_tp_group().rank_in_group
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
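        # Example (illustrative checkpoint name): an entry such as
        # "model.layers.0.self_attn.q_proj.weight" is renamed below to
        # "model.layers.0.self_attn.qkv_proj.weight" and loaded into the "q"
        # shard of the fused QKV parameter; gate_proj/up_proj are handled the
        # same way for the fused gate_up_proj parameter.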

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())  # from model
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # =======================================================
            # BF: added to allow loading with fewer layers.
            if 'layers' in name:
                layer_idx = int(name.split('layers.')[-1].split('.')[0])
                if layer_idx >= self.model.end_layer:
                    continue

            if "rotary_emb.inv_freq" in name:
                continue

            if "module" in name:
                continue

            if name.endswith('kv_cache_offset'):
                continue

            if name.endswith("k_proj.kv_cache_scale"):
                remapped_kv_scale_name = name.replace(
                    "k_proj.kv_cache_scale", "attn.key_antiquant_scale")
                if remapped_kv_scale_name not in params_dict:
                    logger.warning_once(
                        "Found kv scale in the checkpoint "
                        f"(e.g. {name}), but not found the expected "
                        f"name in the model "
                        f"(e.g. {remapped_kv_scale_name}). "
                        "kv-scale is not loaded.")
                    continue
                else:
                    name = remapped_kv_scale_name
                    param = params_dict[name]
                    loaded_weight = torch.tensor_split(loaded_weight,
                                                       tp_size,
                                                       dim=0)[tp_rank]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

            if name.endswith("v_proj.kv_cache_scale"):
                remapped_kv_scale_name = name.replace(
                    "v_proj.kv_cache_scale", "attn.value_antiquant_scale")
                if remapped_kv_scale_name not in params_dict:
                    logger.warning_once(
                        "Found kv scale in the checkpoint "
                        f"(e.g. {name}), but not found the expected "
                        f"name in the model "
                        f"(e.g. {remapped_kv_scale_name}). "
                        "kv-scale is not loaded.")
                    continue
                else:
                    name = remapped_kv_scale_name
                    param = params_dict[name]
                    loaded_weight = torch.tensor_split(loaded_weight,
                                                       tp_size,
                                                       dim=0)[tp_rank]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
            if is_310p() and "head" in name:
                # On the 300I Duo platform, matmul strongly prefers
                # ACL_FORMAT_FRACTAL_NZ over ACL_FORMAT_FRACTAL_ND. Since the
                # LM head is also implemented as a linear layer, we manually
                # cast its weight format here.
                param.data = torch_npu.npu_format_cast(param.data,
                                                       ACL_FORMAT_FRACTAL_NZ)
        return loaded_params