Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git
Synced 2026-02-20 19:50:15 +00:00

Commit message: 大改 (major overhaul)

vllm_npu/quantization/__init__.py (new file, 0 lines)
vllm_npu/quantization/quant_config.py (new file, 488 lines)
@@ -0,0 +1,488 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional

import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import \
    register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs

from vllm_npu.distributed.parallel_state import (get_mlp_tp_group,
                                                 get_otp_group)
from vllm_npu.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_npu.ops.linear import AscendUnquantizedLinearMethod
from vllm_npu.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                            oproj_tp_enable)

from .utils import get_quant_method


@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
    """Config class for Ascend quantization.

    This is a general config class that parses the quantization configs
    supported on Ascend hardware.
    """

    def __init__(self, quant_config: Dict[str, Any]):
        super().__init__()
        self.quant_description = quant_config
        # TODO(whx): remove this adaptation after adding "shared_head"
        # to prefix of DeepSeekShareHead in vLLM.
        extra_quant_dict = {}
        for k in self.quant_description.keys():
            if "shared_head" in k:
                new_k = k.replace(".shared_head.", ".")
                extra_quant_dict[new_k] = self.quant_description[k]
        self.quant_description.update(extra_quant_dict)

    def __repr__(self) -> str:
        return "AscendQuantConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        return ASCEND_QUANTIZATION_METHOD

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "Ascend hardware does not support \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quant_model_description.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        if torch.npu.is_available():
            return ASCEND_QUANTIZATION_METHOD
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        vllm_config = get_current_vllm_config()
        model_type = vllm_config.model_config.hf_config.model_type
        if model_type in packed_modules_model_mapping:
            self.packed_modules_mapping = packed_modules_model_mapping[
                model_type]
        from vllm.attention.layer import Attention
        if prefix.startswith("language_model"):
            prefix = prefix.split('.', 1)[-1]
        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedLinearMethod()
            return AscendLinearMethod(self, prefix,
                                      self.packed_modules_mapping)
        elif isinstance(layer, Attention) and \
                'fa_quant_type' in self.quant_description.keys() and \
                self.quant_description['fa_quant_type'] is not None:
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, Attention) and self.quant_description.get(
                'kv_quant_type') == 'C8':
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedFusedMoEMethod(layer.moe_config)
            return AscendFusedMoEMethod(self, prefix,
                                        self.packed_modules_mapping)
        elif isinstance(layer, VocabParallelEmbedding):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedEmbeddingMethod()
            return AscendEmbeddingMethod(self, prefix,
                                         self.packed_modules_mapping)
        return None

    def is_layer_skipped_ascend(
            self,
            prefix: str,
            fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in fused_mapping[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = self.quant_description[shard_prefix +
                                                          '.weight'] == "FLOAT"

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "must have the same precision.")
        elif "experts" in prefix:
            # For the experts' prefix (e.g., "model.layers.3.mlp.experts")
            # Assume all experts within the same MLP use the same quantization method
            experts_quant_description = [
                self.quant_description[layer]
                for layer in self.quant_description if prefix in layer
            ]
            is_skipped = any(quantization == "FLOAT"
                             for quantization in experts_quant_description)
        else:
            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"

        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        return []


packed_modules_model_mapping = {
    "qwen3_moe": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    },
    "deepseek_v2": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
        "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
    },
    "deepseek_v3": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
        "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
    },
    "kimi_k2": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    "deepseek_v32": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    # NOTE 1. The MTP layer of DeepSeek is not quantized on the NPU.
    # NOTE 2. The description file generated by the current msmodelslim tool does not
    # contain MTP layer info. Please add it manually and set the value to FLOAT.
    "deepseek_mtp": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    "qwen3_next": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
    },
    "qwen2_5_vl": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    },
    "glm4_moe": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
}


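# ---------------------------------------------------------------------------
# Editorial sketch, not part of the committed file: how a fused prefix is
# resolved against packed_modules_model_mapping before the "FLOAT" check in
# AscendQuantConfig.is_layer_skipped_ascend. The prefix below is a made-up
# example.
# ---------------------------------------------------------------------------
_example_prefix = "model.layers.0.self_attn.qkv_proj"
_example_proj = _example_prefix.split(".")[-1]  # "qkv_proj"
_example_shards = [
    _example_prefix.replace(_example_proj, shard)
    for shard in packed_modules_model_mapping["qwen3_moe"][_example_proj]
]
assert _example_shards == [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
# The layer is skipped only if every "<shard>.weight" entry in the quant
# description is "FLOAT"; mixed FLOAT/quantized shards raise a ValueError.
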
class AscendLinearMethod(LinearMethodBase):
    """Linear method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "linear",
                                             packed_modules_mapping)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                   output_size_per_partition,
                                                   params_dtype)

        # Extract packing information (if present)
        packed_dim = weight_dict.pop("_packed_dim", None)
        packed_factor = weight_dict.pop("_packed_factor", None)

        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})

            # Set packing attributes if the weight is packed
            if packed_dim is not None and packed_factor is not None:
                set_weight_attrs(param, {
                    "packed_dim": packed_dim,
                    "packed_factor": packed_factor
                })

            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(data=pertensor_param,
                                            weight_loader=weight_loader)
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)
            param.weight_loader = extra_weight_attrs.get("weight_loader")

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype)
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # NOTE: In w4a8 quantization implementation,
        # for down_proj and o_proj scale_bias shape is [output_size, 16],
        # others are [output_size, 1]
        layer_type = "row" if isinstance(layer,
                                         RowParallelLinear) else "others"

        pergroup_dict = self.quant_method.get_pergroup_param(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype,
            layer_type=layer_type)
        for pergroup_name, pergroup_param in pergroup_dict.items():
            param = torch.nn.Parameter(pergroup_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(pergroup_name, param)
            set_weight_attrs(param, extra_weight_attrs)
            if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
                param.input_dim = 1

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if isinstance(layer, RowParallelLinear):
            if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
                tp_rank = get_otp_group().rank_in_group
            elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
                tp_rank = get_mlp_tp_group().rank_in_group
            else:
                tp_rank = get_tensor_model_parallel_rank()
        else:
            tp_rank = 0
        return self.quant_method.apply(layer, x, bias, tp_rank)


class AscendKVCacheMethod(BaseKVCacheMethod):
    """KVCache method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "attention")

    def create_weights(self, layer: torch.nn.Module) -> None:
        # Different from linear method, there are no weight processing/slicing
        # steps for attention in vllm. So the whole process of create weights
        # is hidden into the specific quant method.
        self.quant_method.create_weights(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        return self.quant_method.apply(layer, query, key, value, kv_cache,
                                       attn_metadata, attn_type, scale, output)


class AscendFusedMoEMethod(FusedMoEMethodBase):
    """FusedMoE method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]):
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "moe",
                                             packed_modules_mapping)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        weight_param = self.quant_method.get_weight(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in weight_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        per_group_param = [
            "weight_scale_second", "weight_offset_second", "scale_bias"
        ]
        dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in dynamic_quant_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)
            if any(fields in param_key for fields in per_group_param):
                setattr(param, "quant_method",
                        FusedMoeWeightScaleSupported.GROUP.value)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num=0,
        **kwargs,
    ) -> torch.Tensor:
        return self.quant_method.apply(
            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
            global_num_experts, expert_map, topk_group, num_expert_group,
            custom_routing_function, scoring_func, e_score_correction_bias,
            is_prefill, enable_force_load_balance, log2phy,
            global_redundant_expert_num, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
        # TODO: implement this function
        pass


class AscendEmbeddingMethod(AscendLinearMethod):
    """Embedding method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "linear",
                                             packed_modules_mapping)
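A minimal usage sketch for the config above (illustrative only, not part of the
commit; the exact key set of quant_model_description.json is produced by the
msmodelslim export and may differ):

    # Hypothetical, hand-written stand-in for quant_model_description.json.
    quant_description = {
        "model.layers.0.self_attn.q_proj.weight": "W8A8",
        "model.layers.0.self_attn.k_proj.weight": "W8A8",
        "model.layers.0.self_attn.v_proj.weight": "W8A8",
        "model.layers.0.mlp.gate_proj.weight": "FLOAT",
        "model.layers.0.mlp.up_proj.weight": "FLOAT",
        "kv_quant_type": "C8",
    }

    # vLLM discovers the file via get_config_filenames() and builds the config:
    config = AscendQuantConfig.from_config(quant_description)
    # get_quant_method() then hands back an AscendLinearMethod, AscendKVCacheMethod,
    # AscendFusedMoEMethod or an unquantized fallback per layer, keyed by its prefix.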
vllm_npu/quantization/utils.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Any, Dict, Optional, Type

from vllm.logger import logger

from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
                           AscendW4A8DynamicLinearMethod)
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                   AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                           AscendW8A8DynamicLinearMethod)

ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
    "W4A8_DYNAMIC": {
        "linear": AscendW4A8DynamicLinearMethod,
        "moe": AscendW4A8DynamicFusedMoEMethod,
    },
    "W4A4_FLATQUANT_DYNAMIC": {
        "linear": AscendW4A4FlatQuantDynamicLinearMethod,
    },
    "W8A8": {
        "linear": AscendW8A8LinearMethod,
        "moe": AscendW8A8FusedMoEMethod,
        "attention": AscendC8KVCacheMethod,
    },
    "W8A8_DYNAMIC": {
        "linear": AscendW8A8DynamicLinearMethod,
        "moe": AscendW8A8DynamicFusedMoEMethod,
    },
    "C8": {
        "attention": AscendC8KVCacheMethod,
    },
}


def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
                          packed_modules_mapping: Dict[str, Any]):
    proj_name = prefix.split(".")[-1]
    if proj_name in packed_modules_mapping:
        quant_type = None
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in packed_modules_mapping[proj_name]
        ]
        for shard_prefix in shard_prefixes:
            shard_quant_type = quant_description[shard_prefix + '.weight']

            if quant_type is None:
                quant_type = shard_quant_type
            elif shard_quant_type != quant_type:
                raise ValueError(
                    f"Not all shards of {prefix} are quantized with the same "
                    f"quant type. Shard {proj_name} uses {shard_quant_type}, "
                    f"but another shard uses {quant_type}. Please check the "
                    "quantization config.")
    elif "experts" in prefix:
        # For the experts' prefix (e.g., "model.layers.3.mlp.experts")
        # Assume all experts within the same MLP use the same quantization method
        experts_quant_description = set(quant_description[layer]
                                        for layer in quant_description
                                        if prefix in layer)
        if not len(experts_quant_description) == 1:
            raise RuntimeError(
                f"{prefix} has different quantization types: {experts_quant_description}."
            )
        quant_type = experts_quant_description.pop()
    else:
        quant_type = quant_description[prefix + '.weight']
    return quant_type


def get_quant_method(quant_description: Dict[str, Any],
                     prefix: str,
                     layer_type: str,
                     packed_modules_mapping: Optional[Dict[str, Any]] = None):
    logger.info_once("Using the vLLM Ascend Quantization now!")
    if packed_modules_mapping is None:
        packed_modules_mapping = dict()
    # Attention
    if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
        quant_type = quant_description['fa_quant_type']
    # Use KVCache int8
    elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
        quant_type = quant_description['kv_quant_type']
    # Linear
    else:
        quant_type = get_linear_quant_type(quant_description, prefix,
                                           packed_modules_mapping)
    if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys():
        method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type]
        if layer_type in method_map.keys():
            method_cls = method_map[layer_type]
            return method_cls()
        else:
            raise NotImplementedError(
                f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
            )
    raise NotImplementedError("Currently, vLLM Ascend only supports the following quant types: " \
                              f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
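An illustrative sketch of the dispatch above (not part of the commit, and assuming
the method classes can be constructed outside a running engine): the per-weight type
string is looked up for the layer prefix, then mapped to a method class through
ASCEND_QUANTIZATION_METHOD_MAP.

    quant_description = {"model.layers.0.mlp.down_proj.weight": "W8A8"}
    method = get_quant_method(quant_description,
                              prefix="model.layers.0.mlp.down_proj",
                              layer_type="linear")
    assert isinstance(method, AscendW8A8LinearMethod)
    # An unknown type string, or an unsupported (quant type, layer type) pair,
    # raises NotImplementedError instead of silently falling back.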
vllm_npu/quantization/w4a4_flatquant_dynamic.py (new file, 193 lines)
@@ -0,0 +1,193 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from typing import Any, Dict, Optional, Tuple

import torch
import torch_npu

KRONECKER_QUANT_MAX_BATCH_SIZE = 32768


def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
    original_device = weight_tensor.device
    weight_tensor_npu = weight_tensor.npu()
    weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(
        weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
    return weight_int4_packed.to(original_device)


def get_decompose_dim(n):
    a = int(math.sqrt(n))
    if a * a < n:
        a += 1
    while True:
        tmp = a * a - n
        b = int(math.sqrt(tmp))
        if b * b == tmp:
            break
        a += 1
    return a - b, a + b

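# Editorial note, not part of the committed file: get_decompose_dim(n) searches
# for integers a, b with a*a - b*b == n and returns (a - b, a + b), i.e. a
# factor pair of n whose members straddle sqrt(n). FlatQuant uses that pair as
# the side lengths of the left/right Kronecker transform matrices.
assert get_decompose_dim(3072) == (48, 64)  # 56*56 - 3072 = 64 = 8*8, and 48*64 == 3072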

# TODO: This function is a temporary workaround for the npu_kronecker_quant operator,
# which has a limitation on the maximum batch size (dim0). This wrapper should be
# removed once the operator supports larger inputs natively.
def batched_kronecker_quant(
    x: torch.Tensor,
    left_trans: torch.Tensor,
    right_trans: torch.Tensor,
    clip_ratio: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_tokens = x.shape[0]
    if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
        return torch_npu.npu_kronecker_quant(x,
                                             left_trans,
                                             right_trans,
                                             clip_ratio=clip_ratio,
                                             dst_dtype=torch.int32)
    x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0)
    processed_chunks = [
        torch_npu.npu_kronecker_quant(chunk,
                                      left_trans,
                                      right_trans,
                                      clip_ratio=clip_ratio,
                                      dst_dtype=torch.int32)
        for chunk in x_chunks
    ]
    quantized_list, scale_list = zip(*processed_chunks)
    x_quantized_int4 = torch.cat(quantized_list, dim=0)
    activation_scale = torch.cat(scale_list, dim=0)
    return x_quantized_int4, activation_scale


class AscendW4A4FlatQuantDynamicLinearMethod:
    """Linear method for Ascend W4A4_FLATQUANT_DYNAMIC.

    This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization.
    - Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8 and packed to int32 during loading
    - Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for distribution smoothing
    - Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights
    """
    input_size = 0

    def __init__(self):
        self.transpose_weight = False
        self.sym = True

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        if input_size % 8 != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by 8 for int4 packing"
            )
        AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        left_trans_dim, right_trans_dim = get_decompose_dim(
            AscendW4A4FlatQuantDynamicLinearMethod.input_size)
        params_dict["left_trans"] = torch.empty(left_trans_dim,
                                                left_trans_dim,
                                                dtype=params_dtype)
        params_dict["right_trans"] = torch.empty(right_trans_dim,
                                                 right_trans_dim,
                                                 dtype=params_dtype)
        params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32)
        return params_dict

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=torch.float32)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=torch.float32)
        return params_dict

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        return {}

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        original_dtype = x.dtype
        input_shape = x.shape
        in_features = input_shape[-1]
        left_dim = layer.left_trans.shape[0]
        right_dim = layer.right_trans.shape[0]
        if left_dim * right_dim != in_features:
            raise ValueError(
                f"FlatQuant transform matrices dimension mismatch: "
                f"left_dim({left_dim}) * right_dim({right_dim}) != in_features({in_features})"
            )
        left_trans_matched = layer.left_trans.to(original_dtype)
        right_trans_matched = layer.right_trans.to(original_dtype)
        x_reshaped = x.view(-1, left_dim, right_dim)
        x_quantized_int4, activation_scale = batched_kronecker_quant(
            x_reshaped, left_trans_matched, right_trans_matched,
            layer.aclnn_clip_ratio)
        x_quantized_reshaped = x_quantized_int4.view(-1,
                                                     left_dim * right_dim // 8)
        pertoken_scale = activation_scale.view(-1).to(torch.float32)
        output = torch_npu.npu_quant_matmul(x_quantized_reshaped,
                                            layer.weight_packed.t(),
                                            layer.weight_scale.view(-1).to(
                                                torch.float32),
                                            pertoken_scale=pertoken_scale,
                                            bias=None,
                                            output_dtype=original_dtype)
        output = output.view(*input_shape[:-1], -1)
        if bias is not None:
            output = output + bias.to(original_dtype)
        return output

    def process_weights_after_loading(self, layer):
        weight_packed = pack_int4_weights(layer.weight.data)
        if self.transpose_weight:
            weight_packed = weight_packed.transpose(0, 1).contiguous()
        layer.register_parameter(
            'weight_packed',
            torch.nn.Parameter(weight_packed, requires_grad=False))
        del layer.weight
        layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.to(torch.float32)
        layer.left_trans = torch.nn.Parameter(
            layer.left_trans.data.t().contiguous())
        layer.right_trans = torch.nn.Parameter(layer.right_trans.data)
        layer.clip_ratio = torch.nn.Parameter(
            layer.clip_ratio.data.to(torch.float32))
        layer.aclnn_clip_ratio = layer.clip_ratio.item()
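A sketch of the activation shape flow in AscendW4A4FlatQuantDynamicLinearMethod.apply
(illustrative only, not part of the commit; the sizes are made up but respect the
checks in the code):

    hidden = 3072                          # in_features of the layer
    left_dim, right_dim = 48, 64           # get_decompose_dim(3072)
    assert left_dim * right_dim == hidden  # enforced inside apply()
    # x: [tokens, hidden] is viewed as [tokens, left_dim, right_dim] for the
    # Kronecker transform; npu_kronecker_quant then emits int4 values packed
    # eight per int32, so the matmul input becomes [tokens, hidden // 8] int32
    # together with a per-token activation scale.
    assert hidden // 8 == 384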
vllm_npu/quantization/w4a8_dynamic.py (new file, 490 lines)
@@ -0,0 +1,490 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional

import numpy as np
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context

from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.distributed.parallel_state import get_mc2_group
from vllm_npu.ops.moe.experts_selector import select_experts
from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ


class AscendW4A8DynamicLinearMethod:
    """Linear method for Ascend W4A8_DYNAMIC.
    """

    def __init__(self):
        self.transpose_weight = True

        vllm_config = get_current_vllm_config()
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 256)
        quant_version = vllm_config.quant_config.quant_description.get(
            "version", "0")
        self.new_quant_version = quant_version == "1.0.0"

        from vllm.distributed import get_tensor_model_parallel_world_size
        self.tp_size = get_tensor_model_parallel_world_size()

    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create weight parameters.

        For the new quantization version (double int4 packed into int8), the output
        dimension is compressed by a factor of 2 (e.g., [2048, 3072] -> [1024, 3072]).
        The returned dict includes "_packed_dim" and "_packed_factor" for vLLM's
        weight loader.
        """
        params_dict = {}

        if self.new_quant_version:
            # double int4 pack into int8: output dimension is compressed
            pack_factor = 2
            actual_output_size = output_size // pack_factor
            params_dict["weight"] = torch.empty(actual_output_size,
                                                input_size,
                                                dtype=torch.int8)
            # Add packing information for vLLM's weight_loader
            params_dict["_packed_dim"] = 0
            params_dict["_packed_factor"] = pack_factor
        else:
            params_dict["weight"] = torch.empty(output_size,
                                                input_size,
                                                dtype=torch.int8)

        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_perchannel_param(output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Create per-group quantization parameters.
        """
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        params_dict["weight_scale_second"] = torch.empty(output_size,
                                                         input_size //
                                                         self.group_size,
                                                         dtype=params_dtype)
        params_dict["weight_offset_second"] = torch.empty(output_size,
                                                          input_size //
                                                          self.group_size,
                                                          dtype=params_dtype)

        # NOTE: In the w4a8 quantization implementation, the scale_bias shape is
        # [output_size, 16] for down_proj and o_proj (layer_type == "row"),
        # and [output_size, 1] otherwise.
        if self.new_quant_version:
            scale_bias_dim = 16 if layer_type == "row" else 1

            params_dict["scale_bias"] = torch.empty(output_size,
                                                    scale_bias_dim,
                                                    dtype=torch.float32)
        return params_dict

    @staticmethod
    def process_scale_second(weight: torch.Tensor,
                             scale: torch.Tensor,
                             per_group_scale: torch.Tensor,
                             is_new_quant: bool = False):
        """
        Process the scale for second-level quantization.

        Args:
            weight: weight tensor [k, n] (in the new version, n is already compressed to n/2)
            scale: first-level quantization scale [output_size]
            per_group_scale: second-level per-group quantization scale [group_num, n_scale]
            is_new_quant: whether this is the new quantization version (weight already compressed)

        Returns:
            (antiquant_scale, bias): dequantization scale and bias (bias=None for the new version)
        """
        k, n = weight.shape
        group_num, n_scale = per_group_scale.shape

        if is_new_quant:
            # Restore logical dimension for compressed weight
            n = n * 2

        bias = None
        if not is_new_quant:
            weight_high = weight.to(torch.float32).reshape(
                group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
            weight_high = weight_high.reshape(k, n)
            bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
            # NOTE: scale_bias is not used currently, because msmodelslim w4a8
            # uses symmetric quantization.

        # TODO: support potential future asymmetric quantization
        antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
        return antiquant_scale.npu(), bias

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = None,
    ) -> torch.Tensor:
        return torch_npu.npu_weight_quant_batchmatmul(
            x,
            layer.weight,
            antiquant_scale=layer.weight_scale_second.to(x.dtype),
            antiquant_group_size=self.group_size,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module):
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten().to(
            torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
        layer.weight_scale_second.data, scale_bias = self.process_scale_second(
            layer.weight.data,
            layer.weight_scale.data,
            layer.weight_scale_second.data.transpose(0, 1).contiguous(),
            is_new_quant=self.new_quant_version,
        )

        if self.new_quant_version:
            # Process the loaded data based on layer type
            if hasattr(layer, "scale_bias"):
                if layer.scale_bias.data.shape[1] == 1:
                    layer.scale_bias.data = layer.scale_bias.data.flatten()
                else:
                    layer.scale_bias.data = layer.scale_bias.data.contiguous()
        else:
            if scale_bias is not None:
                param = torch.nn.Parameter(scale_bias, requires_grad=False)
                layer.register_parameter("weight_scale_bias", param)

        # Convert to NPU-specific int4pack format
        if self.new_quant_version:
            # weights on disk are already in packed int4 format
            # pack 4 int8(int4*2) to int32
            assert layer.weight.data.shape[-1] % 4 == 0, \
                f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
            layer.weight.data = layer.weight.data.view(
                torch.int32).contiguous()
        else:
            # weights are not compressed
            # need to be packed via npu_convert_weight_to_int4pack
            layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
                layer.weight.data.to(torch.int32))

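# ---------------------------------------------------------------------------
# Editorial sketch, not part of the committed file: the two-level dequantization
# that process_scale_second() above collapses into a single antiquant scale.
# With a per-output-channel first-level scale s1[n] and a per-(input-group,
# output-channel) second-level scale s2[g, n], the weight is reconstructed as
#     W_float[k, n] ~= W_int4[k, n] * s2[k // group_size, n] * s1[n]
# so the tensor handed to npu_weight_quant_batchmatmul is simply s1 * s2.
# ---------------------------------------------------------------------------
def _demo_fused_antiquant_scale(k: int = 8, n: int = 4, group_size: int = 4):
    s1 = torch.rand(n)                      # first-level, per output channel
    s2 = torch.rand(k // group_size, n)     # second-level, per group and channel
    w_int4 = torch.randint(-8, 8, (k, n)).float()
    fused = s1 * s2                         # what process_scale_second builds
    lhs = w_int4 * s2.repeat_interleave(group_size, dim=0) * s1
    rhs = w_int4 * fused.repeat_interleave(group_size, dim=0)
    assert torch.allclose(lhs, rhs)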

class AscendW4A8DynamicFusedMoEMethod:
    """FusedMoe method for Ascend W4A8_DYNAMIC.
    """

    def __init__(self):
        self.transpose_weight = True

        self.ep_group = get_ep_group()

        vllm_config = get_current_vllm_config()
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 256)
        # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
        self.is_per_channel_weight = self.group_size == 0
        quant_version = vllm_config.quant_config.quant_description.get(
            "version", "0")
        # NOTE: new quantized weights: 2 int4 values packed into one int8
        self.new_quant_version = quant_version == "1.0.0"
        self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
        if self.new_quant_version and self.tp_size > 16:
            raise ValueError(
                "The current weight does not support moe part tp>16.")

        try:
            device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            self.moe_all_to_all_group_name = ""

    def get_weight(self, num_experts: int,
                   intermediate_size_per_partition: int, hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        if self.new_quant_version:
            w13_output_size = intermediate_size_per_partition
            w2_output_size = hidden_sizes // 2
        else:
            w13_output_size = 2 * intermediate_size_per_partition
            w2_output_size = hidden_sizes

        param_dict["w13_weight"] = torch.empty(num_experts,
                                               w13_output_size,
                                               hidden_sizes,
                                               dtype=torch.int8)
        param_dict["w2_weight"] = torch.empty(num_experts,
                                              w2_output_size,
                                              intermediate_size_per_partition,
                                              dtype=torch.int8)
        return param_dict

    def get_dynamic_quant_param(self, num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32)

        param_dict["w13_weight_offset"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32)

        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=torch.float32)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=torch.float32)
        if not self.is_per_channel_weight:
            param_dict["w13_weight_scale_second"] = torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_sizes // self.group_size,
                dtype=torch.float32)
            param_dict["w13_weight_offset_second"] = torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_sizes // self.group_size,
                dtype=torch.float32)

            param_dict["w2_weight_scale_second"] = torch.empty(
                num_experts,
                hidden_sizes,
                intermediate_size_per_partition // self.group_size,
                dtype=torch.float32)
            param_dict["w2_weight_offset_second"] = torch.empty(
                num_experts,
                hidden_sizes,
                intermediate_size_per_partition // self.group_size,
                dtype=torch.float32)

        if self.new_quant_version:
            param_dict["w13_scale_bias"] = torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                1,
                dtype=torch.float32)
            param_dict["w2_scale_bias"] = torch.empty(num_experts,
                                                      hidden_sizes,
                                                      16 // self.tp_size,
                                                      dtype=torch.float32)

        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

        # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        # This is a naive implementation of expert load balancing, used to avoid
        # accumulating too many tokens on a single rank. It is currently only
        # activated during profile runs.
        if enable_force_load_balance:
            topk_ids = torch.randint_like(
                topk_ids, 0, global_num_experts - global_redundant_expert_num)

        topk_weights = topk_weights.to(x.dtype)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_scale_bias=layer.w13_scale_bias,
            w2_scale_bias=layer.w2_scale_bias,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int4_w4a8=True,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            dynamic_eplb=self.dynamic_eplb)

    def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
        scale = scale.transpose(1, 2).contiguous()
        if self.is_per_channel_weight:
            scale_np = scale.cpu().numpy()
            scale_np.dtype = np.uint32
            scale_uint64_tensor = torch.from_numpy(scale_np.astype(
                np.int64)).npu()
            return scale_uint64_tensor, None
        per_group_scale = per_group_scale.transpose(1, 2).contiguous()
        group_num, k, n = weight.shape
        # The n dimension of the new-version weight is halved by packing,
        # so it needs to be restored here.
        if self.new_quant_version:
            n = n * 2
        per_group_scale = per_group_scale.reshape(group_num, -1, n)
        group_num, quantgroup_num, n = per_group_scale.shape
        bias = None
        if not self.new_quant_version:
            weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
                per_group_scale.reshape([group_num, quantgroup_num, 1, n])
            weight_high = weight_high.reshape([group_num, k, n])
            bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
        scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
            torch.float32)
        scale_fp32_np = scale_fp32.cpu().numpy()
        scale_fp32_np.dtype = np.uint32
        sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
                                 dtype=np.uint32)

        sscale_uint64[..., ::2] = scale_fp32_np

        sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
                                             dtype=np.int64).copy()
        sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
            group_num, quantgroup_num, n)
        sscale_uint64_tensor = sscale_uint64_tensor.npu()
        return sscale_uint64_tensor, bias

    def update_bias(self, layer, w13_bias, w2_bias):
        if self.new_quant_version:
            layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
                1, 2).contiguous().sum(axis=1)
            layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
                1, 2).contiguous().sum(axis=1)
        else:
            w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
            layer.register_parameter("w13_scale_bias", w13_scale_bias)
            w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
            layer.register_parameter("w2_scale_bias", w2_scale_bias)

    def pack_to_int32(self, weight: torch.Tensor):
        if self.new_quant_version:
            # pack 4 int8 (int4*2) into int32, because PyTorch has to use int32 to represent int4
            assert weight.shape[
                -1] % 4 == 0, "the last dim of weight needs to be divisible by 4"
            return weight.view(torch.int32).contiguous()
        else:
            return torch_npu.npu_quantize(weight.to(torch.float32),
                                          torch.tensor([1.]).npu(), None,
                                          torch.quint4x2, -1, False)

    def process_weights_after_loading(self, layer):
        if self.transpose_weight:
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()

        w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
            layer, "w13_weight_scale_second") else None
        w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
            layer, "w2_weight_scale_second") else None
        layer.w13_weight_scale.data, w13_bias = self.process_scale(
            layer.w13_weight, layer.w13_weight_scale.data,
            w13_weight_scale_second)
        layer.w2_weight_scale.data, w2_bias = self.process_scale(
            layer.w2_weight, layer.w2_weight_scale.data,
            w2_weight_scale_second)
        if hasattr(layer, "w13_weight_scale_second"):
            # scale_second is no longer used, release this part of the memory
            del layer.w13_weight_scale_second
            del layer.w2_weight_scale_second
            del layer.w13_weight_offset_second
            del layer.w2_weight_offset_second

        self.update_bias(layer, w13_bias, w2_bias)

        layer.w13_weight.data = torch_npu.npu_format_cast(
            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.w2_weight.data = torch_npu.npu_format_cast(
            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
        layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
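A small illustration of the view(torch.int32) packing used by pack_to_int32 and
process_weights_after_loading above (illustrative, not part of the commit): each
int8 element already carries two int4 values, and reinterpreting four consecutive
int8 bytes as one int32 therefore stores eight int4 values per element, shrinking
only the last dimension.

    import torch
    w = torch.zeros(3, 16, dtype=torch.int8)  # toy weight, last dim % 4 == 0
    packed = w.view(torch.int32)
    assert packed.shape == (3, 4)             # 16 int8 bytes -> 4 int32 words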
vllm_npu/quantization/w8a8.py (new file, 674 lines)
@@ -0,0 +1,674 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional

import torch
import torch_npu
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context

from vllm_npu.attention.attention_v1 import AscendAttentionState
from vllm_npu.ops.moe.experts_selector import select_experts
from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz


def quant_per_tensor(in_tensor: torch.Tensor,
                     input_scale: torch.Tensor,
                     input_offset: torch.Tensor,
                     function=False):
    return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
                                  torch.qint8, -1, function)

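# Editorial sketch, not part of the committed file: a plain-PyTorch reference for
# what the static per-tensor activation quantization above is meant to compute.
# The exact rounding and saturation behaviour is defined by torch_npu.npu_quantize;
# this only spells out the usual int8 affine formula. Note that apply() passes the
# scale as a reciprocal (aclnn_input_scale_reciprocal), so the quantizer multiplies.
def _reference_quant_per_tensor(x: torch.Tensor, scale_reciprocal: torch.Tensor,
                                offset: torch.Tensor) -> torch.Tensor:
    q = torch.round(x * scale_reciprocal) + offset
    return q.clamp(-128, 127).to(torch.int8)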
|
||||
class AscendW8A8LinearMethod:
|
||||
"""Linear method for Ascend W8A8.
|
||||
|
||||
Args:
|
||||
w_sym: whether the linear weight is symmetrically quantized.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# aclnn quant matmul requires to transpose matrix B, set to true by default.
|
||||
self.transpose_weight = not is_310p()
|
||||
|
||||
@staticmethod
|
||||
def get_weight(
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
||||
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
|
||||
if params_dtype == torch.bfloat16:
|
||||
params_dict["deq_scale"] = torch.empty(output_size,
|
||||
dtype=torch.float32)
|
||||
elif params_dtype == torch.float16:
|
||||
params_dict["deq_scale"] = torch.empty(output_size,
|
||||
dtype=torch.int64)
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        if x.dtype != torch.int8:
            layer_cls_name = layer.__class__.__name__
            try:
                weight_prefetch_method = get_forward_context(
                ).weight_prefetch_method
            except AssertionError:
                weight_prefetch_method = None

            # prefetch qkvo_proj.weight preprocess
            if weight_prefetch_method:
                weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
                    layer_cls_name=layer_cls_name,
                    weight=layer.weight,
                    start_flag=x,
                )
            # quant
            x = quant_per_tensor(
                x,
                layer.aclnn_input_scale_reciprocal,
                layer.aclnn_input_offset,
            )
            # prefetch qkvo_proj.weight postprocess
            if weight_prefetch_method:
                weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
                    layer_cls_name=layer_cls_name,
                    stop_flag=x,
                )

        quant_bias = layer.quant_bias if tp_rank == 0 else None
        if is_310p():
            # On the 300I Duo platform, we need to transpose again if
            # using NZ format. This transpose can be skipped in torchair.
            output = torch_npu.npu_quant_matmul(
                x,
                layer.weight.data.transpose(1, 0),
                layer.deq_scale,
                bias=quant_bias,
                output_dtype=layer.params_dtype,
            )
        else:
            output = torch_npu.npu_quant_matmul(
                x,
                layer.weight,
                layer.deq_scale,
                bias=quant_bias,
                output_dtype=layer.params_dtype,
            )
        return output

    def process_weights_after_loading(self, layer):
        expanding_factor = layer.weight.data.shape[1]
        layer.aclnn_input_scale = torch.nn.Parameter(
            layer.input_scale.data.repeat(expanding_factor),
            requires_grad=False)
        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
            layer.input_scale.data.repeat(expanding_factor),
            requires_grad=False)
        layer.aclnn_input_offset = torch.nn.Parameter(
            layer.input_offset.data.repeat(expanding_factor),
            requires_grad=False).to(layer.aclnn_input_scale.dtype)
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        if is_enable_nz():
            layer.weight.data = torch_npu.npu_format_cast(
                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)


class AscendW8A8FusedMoEMethod:
    """FusedMoe method for Ascend W8A8.
    """

    def __init__(self):
        self.transpose_weight = True

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(num_experts,
                                               2 *
                                               intermediate_size_per_partition,
                                               hidden_sizes,
                                               dtype=torch.int8,
                                               requires_grad=False)
        param_dict["w2_weight"] = torch.empty(num_experts,
                                              hidden_sizes,
                                              intermediate_size_per_partition,
                                              dtype=torch.int8,
                                              requires_grad=False)
        return param_dict

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32)
        param_dict["w13_weight_offset"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float16)
        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=torch.float32)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=torch.float16)
        param_dict["w2_deq_scale"] = torch.empty(num_experts,
                                                 hidden_sizes,
                                                 dtype=torch.float32)
        param_dict["w13_deq_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            dtype=torch.float32)
        param_dict["w2_input_scale"] = torch.empty(num_experts,
                                                   1,
                                                   dtype=torch.float32)
        param_dict["w13_input_scale"] = torch.empty(num_experts,
                                                    1,
                                                    dtype=torch.float32)
        param_dict["w2_input_offset"] = torch.empty(num_experts,
                                                    1,
                                                    dtype=torch.int8)
        param_dict["w13_input_offset"] = torch.empty(num_experts,
                                                     1,
                                                     dtype=torch.int8)
        param_dict["quant_bias"] = torch.empty(num_experts,
                                               hidden_sizes,
                                               dtype=torch.int32)

        return param_dict

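    # Rough flow of apply() below: select_experts() produces the routing
    # weights/ids, and the expert computation is delegated to
    # fused_experts_310p() on 310P hardware or fused_experts() otherwise (both
    # defined at the end of this file). The assert assumes one router logit per
    # non-redundant global expert.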
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[1] == \
            global_num_experts - global_redundant_expert_num, \
            "Number of global experts mismatch (excluding redundancy)"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        if is_310p():
            return fused_experts_310p(hidden_states=x,
                                      w1=layer.w13_weight,
                                      w1_scale=layer.w13_weight_scale,
                                      w1_input_scale=layer.w13_input_scale,
                                      w2=layer.w2_weight,
                                      w2_scale=layer.w2_weight_scale,
                                      w2_input_scale=layer.w2_input_scale,
                                      topk_weights=topk_weights,
                                      topk_ids=topk_ids,
                                      top_k=top_k,
                                      global_num_experts=global_num_experts,
                                      expert_map=expert_map)
        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w1_scale=layer.w13_weight_scale,
                             w1_input_scale=layer.w13_input_scale,
                             w1_input_offset=layer.w13_input_offset,
                             w2=layer.w2_weight,
                             w2_scale=layer.w2_weight_scale,
                             w2_input_scale=layer.w2_input_scale,
                             w2_input_offset=layer.w2_input_offset,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             top_k=top_k,
                             global_num_experts=global_num_experts,
                             expert_map=expert_map)

    def process_weights_after_loading(self, layer):
        if not is_310p():
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)

        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
            layer.w2_weight_scale.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1)
        expanding_factor_w13 = layer.w13_weight.data.shape[1]
        expanding_factor_w2 = layer.w2_weight.data.shape[1]

        if is_310p():
            layer.w13_input_scale.data = torch.nn.Parameter(
                layer.w13_input_scale.data.max())
            layer.w2_input_scale.data = torch.nn.Parameter(
                layer.w2_input_scale.data.max())
        else:
            layer.w13_input_scale.data = torch.nn.Parameter(
                layer.w13_input_scale.data.repeat(1,
                                                  expanding_factor_w13)[0:1])
            layer.w2_input_scale.data = torch.nn.Parameter(
                layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])

        layer.w13_input_offset.data = torch.nn.Parameter(
            layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
        layer.w2_input_offset.data = torch.nn.Parameter(
            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])

        # Convert the weights to ACL_FORMAT_FRACTAL_NZ. This cast is skipped on
        # 310P because npu_quant_grouped_matmul_dequant in eager mode does not
        # accept ACL_FORMAT_FRACTAL_NZ.
        if not is_310p():
            layer.w13_weight.data = torch_npu.npu_format_cast(
                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
            layer.w2_weight.data = torch_npu.npu_format_cast(
                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()


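# The class below implements int8 ("C8") KV-cache quantization: keys and values
# are quantized with the loaded key/value antiquant scales before being written
# into the cache, and the concatenated scales are handed to the fused attention
# kernel at decode time so it can dequantize on the fly. This summary is
# inferred from the code that follows, not from external documentation.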
class AscendC8KVCacheMethod:

    def __init__(self) -> None:
        self.antiquant_scale_comb = None

    @staticmethod
    def create_weights(layer) -> None:
        param_dict = {}  # num_kv_heads * head_size
        param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
                                                        layer.head_size,
                                                        dtype=torch.float16,
                                                        requires_grad=False)
        param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
                                                          layer.head_size,
                                                          dtype=torch.float16,
                                                          requires_grad=False)
        for weight_name, weight_param in param_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            layer.register_parameter(weight_name, param)

    def process_weights_after_loading(self, layer):
        self.antiquant_scale_comb = torch.cat(
            (layer.key_antiquant_scale.data.unsqueeze(0),
             layer.value_antiquant_scale.data.unsqueeze(0)),
            dim=0).to(torch.float16).contiguous()

    def apply(self, layer, query, key, value, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.view(num_tokens, layer.num_heads * layer.head_size)
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

        # C8
        quant_key = quant_per_tensor(
            key.view(-1, layer.num_kv_heads * layer.head_size),
            layer.key_antiquant_scale.data.view(-1), None, True)
        quant_value = quant_per_tensor(
            value.view(-1, layer.num_kv_heads * layer.head_size),
            layer.value_antiquant_scale.data.view(-1), None, True)

        # View q k v to BSH.
        query = query.view(-1, layer.num_heads, layer.head_size)
        key = key.view(-1, layer.num_kv_heads, layer.head_size)
        value = value.view(-1, layer.num_kv_heads, layer.head_size)
        # TODO: Remove this contiguous in the future.
        value = value.contiguous()

        if kv_cache[0].numel() > 0:
            # if key_cache is None:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping

            block_size = key_cache.shape[1]
            slots_indices = slots.reshape(-1, 1)
            block_indices = slots_indices // block_size
            slots_indices = slots_indices % block_size
            indices = torch.cat((block_indices, slots_indices), dim=1)

            # C8
            torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
            torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)

        # V0-Style scheduler situation.
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
            mask = attn_metadata.attn_mask
            torch_npu._npu_flash_attention(query=query,
                                           key=key,
                                           value=value,
                                           mask=mask,
                                           seq_len=attn_metadata.seq_lens,
                                           scale_value=scale,
                                           num_heads=layer.num_heads,
                                           num_kv_heads=layer.num_kv_heads,
                                           out=output.reshape(query.shape))

        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            raise NotImplementedError("kv cache int8 is not "
                                      "implemented for "
                                      "PrefillCacheHit")
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:  # changed
            if hasattr(attn_metadata, "decode"):
                # torch_air
                decode_meta = attn_metadata.decode
                seq_lens = decode_meta.seq_lens_list
            else:
                seq_lens = attn_metadata.seq_lens
            block_size = key_cache.shape[1]
            query = query.view(num_tokens, 1, layer.num_heads *
                               layer.head_size).contiguous()  # changed

            # [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
            key = key_cache
            value = value_cache

            output = torch_npu.npu_incre_flash_attention(
                query,
                key,
                value,
                num_key_value_heads=layer.num_kv_heads,
                num_heads=layer.num_heads,
                actual_seq_lengths=seq_lens,
                scale_value=scale,
                input_layout='BSH',
                block_size=block_size,
                block_table=attn_metadata.block_tables,
                antiquant_scale=self.antiquant_scale_comb,
            )

        # Normal V1 situation.
        else:
            raise NotImplementedError("kv cache int8 is not "
                                      "implemented for "
                                      "other cases")
        return output


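# The helper below is the 310P-specific expert computation. Roughly (as read
# from the code, with no external reference): tokens are sorted by their routed
# expert id, a cumulative group_list marks each expert's slice, both grouped
# matmuls run in int8 with per-tensor input scales via
# npu_quant_grouped_matmul_dequant, and the SwiGLU activation plus the top-k
# routing weights are applied in between before the outputs are un-sorted and
# summed back per token.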
def fused_experts_310p(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w1_input_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
) -> torch.Tensor:
    ep_size = get_ep_group().world_size
    local_num_experts = global_num_experts // ep_size
    local_num_group = top_k // ep_size

    bsz, _ = hidden_states.shape
    flatten_topk_ids = topk_ids.view(-1)
    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
    sorted_hidden_states = hidden_states.index_select(
        0, sorted_topk_ids // local_num_group)

    experts_id = torch.arange(0,
                              local_num_experts,
                              dtype=topk_ids.dtype,
                              device=topk_ids.device)
    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
        torch.float32).sum(0)
    topk_scales = topk_weights.view(-1).index_select(
        0, sorted_topk_ids).unsqueeze(-1)
    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)

    gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
        x=sorted_hidden_states,
        quantized_weight=w1,
        weight_scale=w1_scale,
        group_list=group_list,
        x_scale=w1_input_scale,
        quant_mode="pertensor")

    gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
        torch.float16)
    gate_up_out *= topk_scales

    down_out = torch_npu.npu_quant_grouped_matmul_dequant(
        x=gate_up_out,
        quantized_weight=w2,
        weight_scale=w2_scale,
        group_list=group_list,
        x_scale=w2_input_scale,
        quant_mode="pertensor")

    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
    unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
    final_hidden_states = unsorted_hidden_states.reshape(
        bsz, top_k // ep_size, -1).sum(1)

    return final_hidden_states


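# The generic (non-310P) path below follows the same pattern but uses
# npu_moe_init_routing_v2 / npu_moe_finalize_routing for the token dispatch and
# npu_grouped_matmul for the expert GEMMs; activations are quantized per tensor
# with the maximum input scale, and dequantization is folded into the grouped
# matmul by passing weight_scale * input_scale as its scale argument. This is a
# reading of the code below, not a statement of the underlying kernel contract.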
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w1_input_scale: torch.Tensor,
    w1_input_offset: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
    w2_input_offset: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
) -> torch.Tensor:
    """
    Fused experts with top-k routing.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        expert_map: Expert mapping of shape (num_experts,).

    Returns:
        hidden_states: Hidden states after routing.
    """
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    """

    original_dtype = hidden_states.dtype
    ep_size = get_ep_group().world_size
    local_num_experts = global_num_experts // ep_size
    w1_input_scale, _ = w1_input_scale.max(0)
    quant_sorted_hidden_states = quant_per_tensor(
        hidden_states,
        w1_input_scale,
        None,
        True,
    )
    if expert_map is not None:
        expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
            quant_sorted_hidden_states,
            topk_ids,
            scale=None,
            active_num=topk_ids.numel(),
            expert_capacity=-1,
            expert_num=local_num_experts,
            drop_pad_mode=0,
            expert_tokens_num_type=1,
            expert_tokens_num_flag=True,
            quant_mode=-1,
            active_expert_range=[0, local_num_experts],
            row_idx_type=0,
        )

    else:
        raise NotImplementedError(
            "The quantized version of MoE models "
            "currently does not support tensor parallelism")
    if expanded_x.dtype != w1.dtype:
        w1_input_scale, _ = w1_input_scale.max(0)
        quant_sorted_hidden_states = quant_per_tensor(
            expanded_x,
            w1_input_scale,
            None,
            True,
        )
    else:
        quant_sorted_hidden_states = expanded_x
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[quant_sorted_hidden_states],
        weight=[w1],
        scale=[w1_scale * w1_input_scale[0]],
        split_item=2,
        group_list_type=1,
        group_type=0,
        group_list=expert_token_count,
        output_dtype=original_dtype,
    )[0]
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    if gate_up_out.dtype != w2.dtype:
        w2_input_scale, _ = w2_input_scale.max(0)
        quant_gate_up_out = quant_per_tensor(
            gate_up_out,
            w2_input_scale,
            None,
            True,
        )
    else:
        quant_gate_up_out = gate_up_out

    down_out = torch_npu.npu_grouped_matmul(
        x=[quant_gate_up_out],
        weight=[w2],
        scale=[w2_scale * w2_input_scale[0]],
        split_item=2,
        group_list_type=1,
        group_type=0,
        group_list=expert_token_count,
        output_dtype=original_dtype,
    )[0]

    if expert_map is not None:
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            down_out,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights.to(down_out.dtype),
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
            drop_pad_mode=2,
        )
    else:
        raise NotImplementedError(
            "The quantized version of MoE models "
            "currently does not support tensor parallelism")

    return final_hidden_states
284
vllm_npu/quantization/w8a8_dynamic.py
Normal file
@@ -0,0 +1,284 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context

from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.distributed.parallel_state import get_mc2_group
from vllm_npu.ops.moe.experts_selector import select_experts
from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


class AscendW8A8DynamicLinearMethod:
    """Linear method for Ascend W8A8_DYNAMIC.
    """

    def __init__(self):
        self.transpose_weight = True

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        return params_dict

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        return {}

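    # apply() below reads an optional per-layer dict, layer._ascend_quant_config
    # (set elsewhere in the plugin); only the keys used in this file are listed
    # here, and the concrete values are illustrative, not prescribed:
    #
    #     {
    #         "output_dtype": torch.bfloat16,  # required when x is already (int8, scale)
    #         "pertoken_scale": True,          # feed the dynamic scale to npu_quant_matmul
    #         "return_scale": False,           # also return the dynamic scale to the caller
    #     }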
    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        config = getattr(layer, "_ascend_quant_config", {})
        if not isinstance(x, tuple):
            output_dtype = config.get("output_dtype", x.dtype)
            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        else:
            assert "output_dtype" in config.keys(), (
                f"DynamicLinearMethod needs an explicitly specified `output_dtype` "
                f"for pre-quantized input, got config [{config}]")
            output_dtype = config["output_dtype"]
            quantized_x, dynamic_scale = x
        pertoken_scale = (dynamic_scale
                          if config.get("pertoken_scale", True) else None)

        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=pertoken_scale,
            bias=bias,
            output_dtype=output_dtype,
        )
        return ((output, dynamic_scale)
                if config.get("return_scale", False) else output)

    def process_weights_after_loading(self, layer):
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        # cast quantized weight tensors to NZ format for higher inference speed
        if is_enable_nz():
            layer.weight.data = torch_npu.npu_format_cast(
                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()


class AscendW8A8DynamicFusedMoEMethod:
    """FusedMoe method for Ascend W8A8_DYNAMIC.
    """

    def __init__(self):
        self.transpose_weight = True

        self.ep_group = get_ep_group()

        vllm_config = get_current_vllm_config()
        ascend_config = get_ascend_config()
        self.use_aclgraph = (
            vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager
            and not ascend_config.torchair_graph_config.enabled)
        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path

        try:
            device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            self.moe_all_to_all_group_name = ""

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(num_experts,
                                               2 *
                                               intermediate_size_per_partition,
                                               hidden_sizes,
                                               dtype=torch.int8)
        param_dict["w2_weight"] = torch.empty(num_experts,
                                              hidden_sizes,
                                              intermediate_size_per_partition,
                                              dtype=torch.int8)
        return param_dict

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=params_dtype)
        param_dict["w13_weight_offset"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=params_dtype)
        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=params_dtype)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=params_dtype)
        return param_dict

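    # A rough outline of apply() below: expert routing comes from
    # select_experts(); when force load balancing is enabled (profile runs) the
    # routed ids are replaced by random experts; the actual computation is then
    # handed to the fused_experts() of the current moe_comm_method taken from
    # the forward context, with use_int8_w8a8=True. The aclgraph branch differs
    # only in which arguments it forwards.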
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[1] == \
            global_num_experts - global_redundant_expert_num, \
            "Number of global experts mismatch (excluding redundancy)"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        # This is a naive implementation of expert load balancing that avoids
        # accumulating too many tokens on a single rank.
        # Currently it is only activated during profile runs.
        if enable_force_load_balance:
            topk_ids = torch.randint_like(
                topk_ids, 0, global_num_experts - global_redundant_expert_num)

        if self.use_aclgraph:
            moe_comm_method = get_forward_context().moe_comm_method
            return moe_comm_method.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                use_int8_w8a8=True,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                expert_map=expert_map,
                dynamic_eplb=self.dynamic_eplb,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num)

        topk_weights = topk_weights.to(x.dtype)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w1_scale=layer.w13_weight_scale.to(torch.float32),
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int8_w8a8=True,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            dynamic_eplb=self.dynamic_eplb)

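    # The post-load step below only changes storage layout: expert weights are
    # transposed to the layout expected by the NPU grouped matmul and cast in
    # place to ACL_FORMAT_FRACTAL_NZ, and the per-channel scales/offsets are
    # flattened to 2-D. A float32 copy of the w13 scale is also kept,
    # presumably for callers that want to avoid the cast done in apply() above.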
    def process_weights_after_loading(self, layer):
        if self.transpose_weight:
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()
        torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
        torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)
        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
            torch.float32)
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
            layer.w2_weight_scale.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1)