mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
256 lines
11 KiB
Python
256 lines
11 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn.parameter import Parameter
|
|
from vllm.distributed import divide
|
|
from vllm.distributed.parallel_state import get_tp_group
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
|
|
VocabParallelEmbedding, pad_vocab_size)
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
|
|
from vllm_npu.distributed.parallel_state import get_lmhead_tp_group
|
|
from vllm_npu.utils import lmhead_tp_enable
|
|
|
|
|
|
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
    """Ascend variant of ``VocabParallelEmbedding``.

    Registered as the Ascend custom op for ``VocabParallelEmbedding``. It
    supports different communication parallel groups, which adds the
    lmhead-TP feature for the pure-DP scenario: the vocabulary is sharded
    over the lm-head TP group when ``lmhead_tp_enable()`` is true and
    ``prefix`` contains ``"head"``, and over the regular TP group otherwise.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Set up sharding metadata and create the (possibly quantized) weights.

        Args:
            num_embeddings: Vocabulary size, including any added embeddings.
            embedding_dim: Size of each embedding vector.
            params_dtype: Weight dtype; falls back to torch's default dtype.
            org_num_embeddings: Original vocabulary size; falls back to
                ``num_embeddings`` when omitted (or falsy).
            padding_size: Multiple the vocabulary sizes are padded to.
            quant_config: Optional quantization config that supplies the
                quant method; ``UnquantizedEmbeddingMethod`` is used when it
                is absent or returns ``None``.
            prefix: Layer name; also used to select the communication group.

        Raises:
            NotImplementedError: If this is a plain embedding layer (not a
                subclass such as an LM head) and the selected quant method
                does not implement ``embedding``.
        """
        # Only nn.Module is initialised: the parent constructor is skipped
        # and its setup is reproduced below on top of ``self.comm_group``.
        # NOTE(review): presumably because the parent hard-wires the default
        # TP group - confirm against the vLLM version in use.
        nn.Module.__init__(self)

        if lmhead_tp_enable() and "head" in prefix:
            self.comm_group = get_lmhead_tp_group()
        else:
            self.comm_group = get_tp_group()

        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group

        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size,
                                               self.tp_rank, self.tp_size)
        self.embedding_dim = embedding_dim

        quant_method = None
        if quant_config is not None:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if quant_method is None:
            quant_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        # The concrete type built by this constructor is the Ascend class (or
        # a subclass), never the vLLM base class itself, so the Ascend class
        # must be accepted here as well or the check below can never fire.
        is_embedding_layer = type(self) in (VocabParallelEmbedding,
                                            AscendVocabParallelEmbedding)
        quant_method_implements_embedding = method_has_implemented_embedding(
            type(quant_method))
        if is_embedding_layer and not quant_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(quant_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.quant_method: QuantizeMethodBase = quant_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.quant_method.create_weights(self,
                                         self.embedding_dim,
                                         [self.num_embeddings_per_partition],
                                         self.embedding_dim,
                                         self.num_embeddings_padded,
                                         params_dtype=params_dtype,
                                         weight_loader=self.weight_loader)

    def _get_masked_input_and_mask(
            self, input_: torch.Tensor, org_vocab_start_index: int,
            org_vocab_end_index: int, num_org_vocab_padding: int,
            added_vocab_start_index: int,
            added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Translate global token ids into shard-local row indices.

        Ids inside ``[org_vocab_start_index, org_vocab_end_index)`` are
        shifted down by ``org_vocab_start_index``. Ids inside
        ``[added_vocab_start_index, added_vocab_end_index)`` are shifted so
        that they land after the original rows and their padding. Every
        other id is mapped to 0.

        Returns:
            A tuple ``(local_ids, invalid_mask)`` where ``invalid_mask`` is
            True for ids that fall in neither range.
        """
        # torch.compile will fuse all of the pointwise ops below
        # into a single kernel, making it very fast
        org_vocab_mask = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        # Adapt: avoid create added_vocab_mask when
        # added_vocab_start_index == added_vocab_end_index.
        if added_vocab_start_index == added_vocab_end_index:
            # Empty added-vocab range: only the original range can match.
            valid_offset = (org_vocab_start_index * org_vocab_mask)
            vocab_mask = org_vocab_mask
        else:
            added_vocab_mask = (input_ >= added_vocab_start_index) & (
                input_ < added_vocab_end_index)
            added_offset = added_vocab_start_index - (
                org_vocab_end_index -
                org_vocab_start_index) - num_org_vocab_padding
            valid_offset = (org_vocab_start_index *
                            org_vocab_mask) + (added_offset * added_vocab_mask)
            vocab_mask = org_vocab_mask | added_vocab_mask
        # Adapt end.
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask

    def forward(self, input_):
        """Look up embeddings for ``input_`` and reduce the partial results.

        When the group has more than one rank, ids are first mapped to
        shard-local indices and the rows looked up for ids outside this
        shard are zeroed before the reduction.
        """
        input_mask = None
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = self._get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        # NOTE(review): the reduction is delegated to a custom op that is not
        # handed ``self.comm_group`` - confirm which group it reduces over.
        output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
        return output
|
|
|
|
|
|
class AscendParallelLMHead(ParallelLMHead):
    """Ascend variant of ``ParallelLMHead``, registered as its custom op.

    Construction is delegated to ``AscendVocabParallelEmbedding.__init__``
    (invoked explicitly, as this class derives from ``ParallelLMHead``);
    afterwards an optional per-partition bias is attached.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the sharded head weights and, if requested, a bias.

        Args:
            num_embeddings: Vocabulary size, including added embeddings.
            embedding_dim: Size of each embedding vector.
            bias: When true, allocate a bias with one entry per local row.
            params_dtype: Dtype forwarded to the weights and the bias.
            org_num_embeddings: Original vocabulary size.
            padding_size: Multiple the vocabulary sizes are padded to.
            quant_config: Optional quantization config; also stored on
                the instance.
            prefix: Layer name.
        """
        AscendVocabParallelEmbedding.__init__(
            self,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            params_dtype=params_dtype,
            org_num_embeddings=org_num_embeddings,
            padding_size=padding_size,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.quant_config = quant_config

        if not bias:
            # Keep the attribute present (as None) when there is no bias.
            self.register_parameter("bias", None)
            return

        local_rows = self.num_embeddings_per_partition
        self.bias = Parameter(torch.empty(local_rows, dtype=params_dtype))
        bias_attrs = {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        }
        set_weight_attrs(self.bias, bias_attrs)
|
|
|
|
|
|
class AscendLogitsProcessor(LogitsProcessor):
    """Ascend variant of ``LogitsProcessor``, registered as its custom op.

    Adds the lmhead-TP feature for the pure-DP scenario: when
    ``lmhead_tp_enable()`` is true the logits are computed through the
    dedicated lm-head TP group, otherwise through the regular gather path.
    """

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits using the path selected by ``lmhead_tp_enable()``."""
        if lmhead_tp_enable():
            compute = self._get_logits_lmheadtp
        else:
            compute = self._get_logits_normal
        return compute(hidden_states, lm_head, embedding_bias)

    def _get_logits_lmheadtp(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Compute logits through the lm-head TP group.

        Hidden states are all-gathered along dim 0 over the lm-head TP
        group, projected with the head's quant method, exchanged with an
        all-to-all, and finally truncated to ``self.org_vocab_size`` along
        the last dimension.
        """
        lmhead_group = get_lmhead_tp_group()
        # Gather hidden states from all devices in the lm-head TP group.
        all_hidden_states = lmhead_group.all_gather(hidden_states, dim=0)
        shard_logits = lm_head.quant_method.apply(lm_head,
                                                  all_hidden_states,
                                                  bias=embedding_bias)
        # Exchange the per-rank logits across the group.
        logits = lmhead_group.all_to_all(shard_logits)
        if logits is None:
            return None
        # Remove paddings in vocab (if any).
        return logits[..., :self.org_vocab_size]

    def _get_logits_normal(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Compute logits through the regular tensor-parallel gather path.

        The local projection is gathered with ``self._gather_logits`` and
        the result, when not ``None``, is truncated to
        ``self.org_vocab_size`` along the last dimension.
        """
        shard_logits = lm_head.quant_method.apply(lm_head,
                                                  hidden_states,
                                                  bias=embedding_bias)
        # Gather logits for tensor parallel.
        logits = self._gather_logits(shard_logits)
        if logits is None:
            return None
        # Remove paddings in vocab (if any).
        return logits[..., :self.org_vocab_size]
|