Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
Major overhaul
vllm_npu/quantization/utils.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Any, Dict, Optional, Type

from vllm.logger import logger

from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
                           AscendW4A8DynamicLinearMethod)
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                   AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                           AscendW8A8DynamicLinearMethod)

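# Maps each quantization type string found in the model's quantization config
# to the Ascend method class instantiated for a given layer kind
# ("linear", "moe", or "attention").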
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
    "W4A8_DYNAMIC": {
        "linear": AscendW4A8DynamicLinearMethod,
        "moe": AscendW4A8DynamicFusedMoEMethod,
    },
    "W4A4_FLATQUANT_DYNAMIC": {
        "linear": AscendW4A4FlatQuantDynamicLinearMethod,
    },
    "W8A8": {
        "linear": AscendW8A8LinearMethod,
        "moe": AscendW8A8FusedMoEMethod,
        "attention": AscendC8KVCacheMethod,
    },
    "W8A8_DYNAMIC": {
        "linear": AscendW8A8DynamicLinearMethod,
        "moe": AscendW8A8DynamicFusedMoEMethod,
    },
    "C8": {
        "attention": AscendC8KVCacheMethod,
    },
}


def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
                          packed_modules_mapping: Dict[str, Any]):
    """Resolve the quantization type of a linear layer from the quant config."""
    proj_name = prefix.split(".")[-1]
    if proj_name in packed_modules_mapping:
        # Packed module (e.g. qkv_proj): every shard must share one quant type.
        quant_type = None
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in packed_modules_mapping[proj_name]
        ]
        for shard_prefix in shard_prefixes:
            shard_quant_type = quant_description[shard_prefix + '.weight']

            if quant_type is None:
                quant_type = shard_quant_type
            elif shard_quant_type != quant_type:
                raise ValueError(
                    f"Not all shards of {prefix} are quantized with the same "
                    f"quant type. Shard {shard_prefix} uses {shard_quant_type}, "
                    f"but another shard uses {quant_type}. Please check the "
                    "quantization config.")
    elif "experts" in prefix:
        # For an experts prefix (e.g. "model.layers.3.mlp.experts"), assume all
        # experts within the same MLP use the same quantization method.
        experts_quant_description = set(quant_description[layer]
                                        for layer in quant_description
                                        if prefix in layer)
        if len(experts_quant_description) != 1:
            raise RuntimeError(
                f"{prefix} has more than one quantization type: "
                f"{experts_quant_description}.")
        quant_type = experts_quant_description.pop()
    else:
        quant_type = quant_description[prefix + '.weight']
    return quant_type


def get_quant_method(quant_description: Dict[str, Any],
                     prefix: str,
                     layer_type: str,
                     packed_modules_mapping: Optional[Dict[str, Any]] = None):
    """Look up and instantiate the Ascend quant method for the given layer."""
    logger.info_once("Using vLLM Ascend quantization now!")
    if packed_modules_mapping is None:
        packed_modules_mapping = dict()
    # Attention
    if '.attn' in prefix and 'fa_quant_type' in quant_description:
        quant_type = quant_description['fa_quant_type']
    # Use int8 KV cache
    elif '.attn' in prefix and 'kv_quant_type' in quant_description:
        quant_type = quant_description['kv_quant_type']
    # Linear
    else:
        quant_type = get_linear_quant_type(quant_description, prefix,
                                           packed_modules_mapping)
    if quant_type in ASCEND_QUANTIZATION_METHOD_MAP:
        method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type]
        if layer_type in method_map:
            method_cls = method_map[layer_type]
            return method_cls()
        else:
            raise NotImplementedError(
                f"Currently, vLLM Ascend doesn't support {quant_type} for "
                f"{layer_type}.")
    raise NotImplementedError(
        "Currently, vLLM Ascend only supports the following quant types: "
        f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
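Usage sketch (illustrative only, not part of the commit): the module prefixes and the quant_description below are made-up assumptions; in practice the description comes from the model's quantization config and packed_modules_mapping from the model definition.

from vllm_npu.quantization.utils import get_quant_method

# Hypothetical quantization description: "<module prefix>.weight" -> quant type,
# plus an optional "kv_quant_type" entry, mirroring what the helpers above read.
quant_description = {
    "model.layers.0.self_attn.q_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.self_attn.k_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.self_attn.v_proj.weight": "W8A8_DYNAMIC",
    "kv_quant_type": "C8",
}
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# Packed linear layer: all three shards agree on W8A8_DYNAMIC, so this call
# returns an AscendW8A8DynamicLinearMethod instance.
linear_method = get_quant_method(quant_description,
                                 prefix="model.layers.0.self_attn.qkv_proj",
                                 layer_type="linear",
                                 packed_modules_mapping=packed_modules_mapping)

# Attention layer: "kv_quant_type" is present, so the "C8" entry of
# ASCEND_QUANTIZATION_METHOD_MAP yields an AscendC8KVCacheMethod instance.
attn_method = get_quant_method(quant_description,
                               prefix="model.layers.0.self_attn.attn",
                               layer_type="attention")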