mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 11:42:30 +00:00
大改
This commit is contained in:
4
setup.py
4
setup.py
@@ -20,5 +20,9 @@ setup(
|
||||
"vllm.platform_plugins": [
|
||||
"npu = vllm_npu:register",
|
||||
],
|
||||
"vllm.general_plugins": [
|
||||
"npu_enhanced_model = vllm_npu:register_model",
|
||||
"npu_kv_connector = vllm_npu:register_connector",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,38 +1,33 @@
|
||||
"""
|
||||
vllm_npu — Ascend NPU platform plugin for vLLM.
|
||||
|
||||
The ``register()`` function is discovered by vLLM through the
|
||||
``vllm.platform_plugins`` entry-point and returns the fully-qualified
|
||||
class name of the platform implementation.
|
||||
"""
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
|
||||
def register():
|
||||
"""Return the fully-qualified name of the NPU platform class."""
|
||||
# Apply CUDA→NPU compatibility patches early so that any code
|
||||
# referencing torch.cuda.Stream / Event / etc. will transparently
|
||||
# be redirected to the torch.npu equivalents.
|
||||
from vllm_npu.cuda_compat import _patch_cuda_to_npu
|
||||
_patch_cuda_to_npu()
|
||||
"""Register the NPU platform."""
|
||||
|
||||
return "vllm_npu.platform.NPUPlatform"
|
||||
|
||||
|
||||
def register_npu_ops():
|
||||
"""Register Ascend NPU op overrides with vLLM's CustomOp system.
|
||||
def register_model():
|
||||
|
||||
Must be called AFTER the platform is established (e.g., during
|
||||
worker init or check_and_update_config), NOT during register().
|
||||
"""
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from .models import register_model
|
||||
register_model()
|
||||
|
||||
from vllm_npu.ops.activation import AscendSiluAndMul
|
||||
from vllm_npu.ops.layernorm import AscendRMSNorm
|
||||
from vllm_npu.ops.rotary_embedding import AscendRotaryEmbedding
|
||||
|
||||
for name, op_cls in {
|
||||
"SiluAndMul": AscendSiluAndMul,
|
||||
"RMSNorm": AscendRMSNorm,
|
||||
"RotaryEmbedding": AscendRotaryEmbedding,
|
||||
}.items():
|
||||
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
||||
def register_connector():
|
||||
from vllm_npu.distributed import register_connector
|
||||
register_connector()
|
||||
|
||||
310
vllm_npu/ascend_config.py
Normal file
310
vllm_npu/ascend_config.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]
|
||||
|
||||
|
||||
def _check_torchair_supported(model_type: str):
|
||||
for supported_model in TORCHAIR_MODEL_LIST:
|
||||
if supported_model in model_type.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AscendConfig:
|
||||
"""
|
||||
Configuration Object for additional_config from vllm.configs.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
|
||||
torchair_graph_config = additional_config.get("torchair_graph_config",
|
||||
{})
|
||||
self.torchair_graph_config = TorchairGraphConfig(
|
||||
torchair_graph_config, vllm_config, additional_config)
|
||||
|
||||
ascend_scheduler_config = additional_config.get(
|
||||
"ascend_scheduler_config", {})
|
||||
self.ascend_scheduler_config = AscendSchedulerConfig(
|
||||
ascend_scheduler_config)
|
||||
|
||||
weight_prefetch_config = additional_config.get(
|
||||
"weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(
|
||||
weight_prefetch_config)
|
||||
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
|
||||
self.expert_map_path = additional_config.get("expert_map_path", None)
|
||||
self.eplb_policy_type = additional_config.get("eplb_policy_type", 1)
|
||||
self.expert_map_record_path = additional_config.get(
|
||||
"expert_map_record_path",
|
||||
None) # Provide path to export expert map
|
||||
self.init_redundancy_expert = additional_config.get(
|
||||
"init_redundancy_expert", 0)
|
||||
self.dynamic_eplb = additional_config.get("dynamic_eplb", False)
|
||||
self.num_iterations_eplb_update = additional_config.get(
|
||||
"num_iterations_eplb_update", 400)
|
||||
self.gate_eplb = additional_config.get("gate_eplb", False)
|
||||
self.num_wait_worker_iterations = additional_config.get(
|
||||
"num_wait_worker_iterations", 30)
|
||||
self.chunked_prefill_for_mla = additional_config.get(
|
||||
"chunked_prefill_for_mla", False)
|
||||
self.enable_shared_expert_dp = additional_config.get(
|
||||
"enable_shared_expert_dp", False
|
||||
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
|
||||
self.multistream_overlap_shared_expert = additional_config.get(
|
||||
"multistream_overlap_shared_expert", False)
|
||||
self.recompute_scheduler_enable = additional_config.get(
|
||||
"recompute_scheduler_enable", False)
|
||||
self.lmhead_tensor_parallel_size = additional_config.get(
|
||||
"lmhead_tensor_parallel_size", None)
|
||||
if self.lmhead_tensor_parallel_size is not None:
|
||||
logger.info(
|
||||
f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
|
||||
)
|
||||
if vllm_config.parallel_config.tensor_parallel_size != 1:
|
||||
raise AssertionError(
|
||||
"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
|
||||
)
|
||||
self.oproj_tensor_parallel_size = additional_config.get(
|
||||
"oproj_tensor_parallel_size", None)
|
||||
if self.oproj_tensor_parallel_size is not None:
|
||||
logger.info(
|
||||
f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario"
|
||||
)
|
||||
if vllm_config.parallel_config.tensor_parallel_size != 1:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in the pure DP scenario"
|
||||
)
|
||||
if not self.torchair_graph_config.enabled:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in graph mode"
|
||||
)
|
||||
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
self.enable_cpu_binding = additional_config.get(
|
||||
"enable_cpu_binding", False)
|
||||
self.pd_tp_ratio = 1
|
||||
self.pd_head_ratio = 1
|
||||
self.num_head_replica = 1
|
||||
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
|
||||
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"prefill", {"tp_size": 1})["tp_size"]
|
||||
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"decode", {"tp_size": 1})["tp_size"]
|
||||
assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size."
|
||||
self.pd_tp_ratio = prefill_tp_size // decode_tp_size
|
||||
if self.pd_tp_ratio > 1:
|
||||
try:
|
||||
# only support Qwen model now
|
||||
# TODO: use a more robust method to get kv_head_num
|
||||
num_kv_head = vllm_config.model_config.hf_config.num_key_value_heads
|
||||
self.num_head_replica = prefill_tp_size // num_kv_head if prefill_tp_size >= num_kv_head else 1
|
||||
prefill_tp_size = min(prefill_tp_size, num_kv_head)
|
||||
decode_tp_size = min(decode_tp_size, num_kv_head)
|
||||
self.pd_head_ratio = prefill_tp_size // decode_tp_size
|
||||
except Exception:
|
||||
raise AssertionError(
|
||||
"Can not get num_key_value_heads from model_config")
|
||||
|
||||
if self.pd_tp_ratio == 0:
|
||||
raise AssertionError(
|
||||
"Only support P node tp size lagger then D node tp size")
|
||||
|
||||
|
||||
class TorchairGraphConfig:
|
||||
"""
|
||||
Configuration Object for torchair_graph_config from additional_config
|
||||
"""
|
||||
|
||||
def __init__(self, torchair_graph_config, vllm_config, additional_config):
|
||||
self.enabled = torchair_graph_config.get("enabled", False)
|
||||
self.mode = torchair_graph_config.get("mode", '')
|
||||
self.use_cached_graph = torchair_graph_config.get(
|
||||
"use_cached_graph", False)
|
||||
self.use_cached_kv_cache_bytes = torchair_graph_config.get(
|
||||
"use_cached_kv_cache_bytes", False)
|
||||
self.graph_batch_sizes = torchair_graph_config.get(
|
||||
"graph_batch_sizes", [])
|
||||
self.graph_batch_sizes_init = torchair_graph_config.get(
|
||||
"graph_batch_sizes_init", False)
|
||||
self.enable_multistream_mla = torchair_graph_config.get(
|
||||
"enable_multistream_mla", False)
|
||||
self.enable_view_optimize = torchair_graph_config.get(
|
||||
"enable_view_optimize", True)
|
||||
self.enable_frozen_parameter = torchair_graph_config.get(
|
||||
"enable_frozen_parameter", True)
|
||||
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
|
||||
self.enable_super_kernel = torchair_graph_config.get(
|
||||
"enable_super_kernel", False)
|
||||
|
||||
if not isinstance(self.graph_batch_sizes, list):
|
||||
raise TypeError("graph_batch_sizes must be list[int]")
|
||||
if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
|
||||
raise ValueError(
|
||||
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
|
||||
)
|
||||
if not self.enabled:
|
||||
if self.mode:
|
||||
raise RuntimeError(
|
||||
"mode is valid only when Torchair graph mode is enabled")
|
||||
if self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_graph is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.graph_batch_sizes:
|
||||
raise RuntimeError(
|
||||
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.graph_batch_sizes_init:
|
||||
raise RuntimeError(
|
||||
"graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_multistream_mla:
|
||||
raise RuntimeError(
|
||||
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_kv_nz:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_super_kernel:
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_super_kernel:
|
||||
if vllm_config.parallel_config.tensor_parallel_size != 1:
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when tensor_parallel_size is 1"
|
||||
)
|
||||
if not additional_config.get("multistream_overlap_shared_expert",
|
||||
False):
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
|
||||
)
|
||||
|
||||
|
||||
class AscendSchedulerConfig:
|
||||
"""
|
||||
Configuration Object for ascend_scheduler_config from additional_config
|
||||
"""
|
||||
|
||||
def __init__(self, ascend_scheduler_config: dict):
|
||||
self.enabled = ascend_scheduler_config.get("enabled", False)
|
||||
# Ascend scheduler is based on vllm v0 scheduler, so we should support
|
||||
# all vllm v0 scheduler configs as well.
|
||||
for k, v in ascend_scheduler_config.items():
|
||||
if not hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class WeightPrefetchConfig:
|
||||
"""
|
||||
Configuration Object for weight_prefetch_config from additional_config
|
||||
"""
|
||||
|
||||
prefetch_ratio: dict = {
|
||||
"attn": {
|
||||
"qkv": 1.0,
|
||||
"o": 1.0,
|
||||
},
|
||||
"moe": {
|
||||
"gate_up": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, weight_prefetch_config: dict):
|
||||
self.enabled = weight_prefetch_config.get("enabled", False)
|
||||
self.prefetch_ratio = weight_prefetch_config.get(
|
||||
"prefetch_ratio", self.prefetch_ratio)
|
||||
|
||||
|
||||
_ASCEND_CONFIG: Optional[AscendConfig] = None
|
||||
|
||||
|
||||
def init_ascend_config(vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
refresh = additional_config.get("refresh",
|
||||
False) if additional_config else False
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is not None and not refresh:
|
||||
return _ASCEND_CONFIG
|
||||
_ASCEND_CONFIG = AscendConfig(vllm_config)
|
||||
return _ASCEND_CONFIG
|
||||
|
||||
|
||||
def clear_ascend_config():
|
||||
global _ASCEND_CONFIG
|
||||
_ASCEND_CONFIG = None
|
||||
|
||||
|
||||
def get_ascend_config():
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is None:
|
||||
raise RuntimeError(
|
||||
"Ascend config is not initialized. Please call init_ascend_config first."
|
||||
)
|
||||
return _ASCEND_CONFIG
|
||||
|
||||
|
||||
def check_ascend_config(vllm_config, enforce_eager):
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
# for eager mode
|
||||
if enforce_eager:
|
||||
# torchair_graph cannot be enabled with eager mode.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
raise RuntimeError(
|
||||
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
|
||||
)
|
||||
# for graph mode
|
||||
else:
|
||||
# torchair_graph case
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
# torchair_graph is supported for deepseek/pangu/qwen model only.
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if not _check_torchair_supported(model_type):
|
||||
raise NotImplementedError(
|
||||
"Torchair graph mode only works with following model types:"
|
||||
f"{TORCHAIR_MODEL_LIST}.")
|
||||
if ascend_config.enable_shared_expert_dp:
|
||||
logger.warning(
|
||||
"enable_shared_expert_dp is not supported for torchair graph mode currently, "
|
||||
"it has been disabled automatically.")
|
||||
# aclgraph case
|
||||
else:
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if "qwen" not in model_type:
|
||||
logger.warning(
|
||||
"ACL Graph is currently experimental. Please "
|
||||
"raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
|
||||
" if you encourage any Error")
|
||||
211
vllm_npu/ascend_forward_context.py
Normal file
211
vllm_npu/ascend_forward_context.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_npu.envs as envs_ascend
|
||||
from vllm_npu.utils import enable_sp, has_layer_idx, is_moe_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_npu.ops.weight_prefetch import WeightPrefetchMethod
|
||||
else:
|
||||
WeightPrefetchMethod = None
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
AllGather = 0
|
||||
All2All = 1
|
||||
MC2 = 2
|
||||
AllGatherEP = 3
|
||||
NaiveMulticast = 4
|
||||
All2AllSeq = 5
|
||||
|
||||
|
||||
class MoECommType(Enum):
|
||||
ALLGATHER = 0
|
||||
MC2 = 1
|
||||
ALLTOALL = 2
|
||||
NAIVE_MULTICAST = 3
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
and is_deepseek_v3_r1):
|
||||
return FusedMoEState.AllGatherEP
|
||||
elif ep_size == 1:
|
||||
if with_prefill:
|
||||
return FusedMoEState.NaiveMulticast
|
||||
else:
|
||||
return FusedMoEState.AllGather
|
||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||
elif ep_size < 16 or with_prefill:
|
||||
return FusedMoEState.All2All
|
||||
else:
|
||||
return FusedMoEState.MC2
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_ascend_forward_context(
|
||||
attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: Optional[int] = None,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
with_prefill: bool = True,
|
||||
in_profile_run: bool = False,
|
||||
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
||||
moe_comm_type: Optional[MoECommType] = None,
|
||||
num_actual_tokens: Optional[int] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||
prefetch_stream: torch.npu.Stream = None,
|
||||
model_instance: torch.nn.Module = None,
|
||||
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
We add some additional param into forward_context.
|
||||
"""
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
vllm_config,
|
||||
virtual_engine=virtual_engine,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
|
||||
from vllm_npu.ops.moe.moe_comm_method import get_moe_comm_method
|
||||
forward_context.moe_comm_type = moe_comm_type
|
||||
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
||||
|
||||
forward_context.with_prefill = with_prefill
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
|
||||
is_deepseek_v3_r1 = hasattr(
|
||||
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||
fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
|
||||
is_deepseek_v3_r1)
|
||||
forward_context.fused_moe_state = fused_moe_state
|
||||
forward_context.in_profile_run = in_profile_run
|
||||
|
||||
# NOTE: This cannot be set using set_forward_context
|
||||
# due to multiple warmups before actual capturing
|
||||
forward_context.capturing = False
|
||||
|
||||
# set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
|
||||
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
|
||||
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
|
||||
# the performance may degrade due to the switching of communication methods.
|
||||
mmrs_fusion = True
|
||||
if is_moe_model(vllm_config):
|
||||
sp_enabled = enable_sp(vllm_config) and \
|
||||
tp_world_size > 1 and num_tokens is not None
|
||||
mmrs_fusion = False
|
||||
else:
|
||||
sp_enabled = enable_sp(vllm_config) and \
|
||||
tp_world_size > 1 and \
|
||||
num_tokens is not None and num_tokens > 1000
|
||||
forward_context.mmrs_fusion = mmrs_fusion
|
||||
|
||||
if sp_enabled:
|
||||
pad_size = (tp_world_size -
|
||||
(num_tokens % tp_world_size)) % tp_world_size
|
||||
forward_context.pad_size = pad_size
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
forward_context.num_tokens = num_tokens
|
||||
|
||||
# set this for rope forward_oot using
|
||||
forward_context.is_first_layer = True
|
||||
|
||||
# set layer_idx to enable optimization features that depend on this information.
|
||||
# This is only applicable to models that contain these necessary attributes.
|
||||
forward_context.layer_idx = None
|
||||
if has_layer_idx(model_instance):
|
||||
forward_context.layer_idx = model_instance.model.start_layer
|
||||
|
||||
# TODO(rjg-lyh): refactor mlp weight prefetch method
|
||||
# set for mlp weight prefetch
|
||||
prefetch_mlp_enabled = envs_ascend.vllm_npu_ENABLE_DENSE_OPTIMIZE and \
|
||||
envs_ascend.vllm_npu_ENABLE_PREFETCH_MLP and \
|
||||
forward_context.layer_idx is not None and \
|
||||
num_tokens is not None and num_tokens < 500
|
||||
if prefetch_mlp_enabled:
|
||||
forward_context.prefetch_stream = prefetch_stream
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.weight_prefetch_method = weight_prefetch_method
|
||||
|
||||
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||
# It will be improved later by implementing operator fusion through the FX graph.
|
||||
#
|
||||
# set for addrmsnorm+quant fusion.
|
||||
# this optim now just support dense models due to the specific operators used.
|
||||
# Once the necessary conditions are met, support for MOE models will also be added.
|
||||
from vllm_npu.quantization.quant_config import AscendQuantConfig
|
||||
model_type_scope = ["llama", "qwen2", "qwen3", "qwen3_moe"]
|
||||
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
||||
vllm_config.model_config.hf_config.model_type in model_type_scope and \
|
||||
forward_context.layer_idx is not None
|
||||
if addrmsnorm_quant_fusion_enabled:
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
|
||||
if vllm_config.model_config.hf_config.model_type == "qwen3_moe":
|
||||
forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe"
|
||||
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
max_tokens_across_dp = \
|
||||
forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if sp_enabled:
|
||||
padded_length = (max_tokens_across_dp + tp_world_size -
|
||||
1) // tp_world_size * tp_world_size
|
||||
pad_size = padded_length - num_tokens
|
||||
forward_context.padded_length = padded_length
|
||||
forward_context.pad_size = pad_size
|
||||
else:
|
||||
max_tokens_across_dp = num_tokens
|
||||
|
||||
forward_context.max_tokens_across_dp = max_tokens_across_dp
|
||||
|
||||
if num_tokens is not None:
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
# NOTE: token num which need to pad to when mc2
|
||||
forward_context.padded_num_tokens = math.ceil(
|
||||
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||
|
||||
if reserved_mc2_mask is not None:
|
||||
mc2_mask = reserved_mc2_mask[:forward_context.
|
||||
padded_num_tokens]
|
||||
mc2_mask[:num_actual_tokens] = True
|
||||
mc2_mask[num_actual_tokens:] = False
|
||||
forward_context.mc2_mask = mc2_mask
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
@@ -1 +0,0 @@
|
||||
"""Ascend NPU attention backends."""
|
||||
|
||||
96
vllm_npu/attention/attention_mask.py
Normal file
96
vllm_npu/attention/attention_mask.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
|
||||
def _generate_attn_mask(max_seq_len, dtype):
|
||||
# Construct lower triangle matrix.
|
||||
mask_flag = torch.ones((max_seq_len, max_seq_len),
|
||||
dtype=torch.bool).tril_()
|
||||
# Create upper triangle matrix used to mark mask positions.
|
||||
mask_flag = ~mask_flag
|
||||
# Currently for fp16 dtype, the mask value should be set to -inf.
|
||||
# TODO: Eliminate this part in the future.
|
||||
mask_value = float('-inf') if dtype == torch.float16 else 1
|
||||
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
|
||||
.masked_fill_(mask_flag, mask_value)
|
||||
return attn_mask
|
||||
|
||||
|
||||
class AttentionMaskBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device = None,
|
||||
):
|
||||
# NOTE: The device argument specifies the target NPU
|
||||
# to be used for the newly added FIA operator.
|
||||
# Only pass this parameter when using the new FIA operator.
|
||||
|
||||
attn_mask = _generate_attn_mask(max_seq_len, dtype)
|
||||
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
self.device = device
|
||||
self.pooling_mask = None
|
||||
assigned_mask_dim = 2048
|
||||
self.chunked_prefill_attn_mask = torch.triu(
|
||||
torch.ones(assigned_mask_dim, assigned_mask_dim),
|
||||
diagonal=1).to(torch.int8).to(device)
|
||||
|
||||
@staticmethod
|
||||
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
|
||||
if dtype == torch.float16:
|
||||
mask_scale_factor = 1
|
||||
elif dtype == torch.bfloat16:
|
||||
mask_scale_factor = -10000
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current operation now only supports data types: torch.float16 and "
|
||||
"torch.bfloat16. Please ensure the input is of one of these types."
|
||||
)
|
||||
return mask_scale_factor
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
def get_pooling_mask(self, device):
|
||||
if self.pooling_mask is None:
|
||||
# the compressed attention mask for npu_fusion_attention sparse mode 4
|
||||
self.pooling_mask = torch.triu(torch.ones(
|
||||
2048, 2048), diagonal=1).to(torch.bool).to(device,
|
||||
non_blocking=True)
|
||||
return self.pooling_mask
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens: torch.Tensor = None,
|
||||
position: torch.Tensor = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
) -> torch.Tensor:
|
||||
return self.chunked_prefill_attn_mask
|
||||
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
|
||||
if seqlen > self._seq_len_cached:
|
||||
self._seq_len_cached = seqlen
|
||||
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
|
||||
if self.attn_mask_cache.dtype != dtype:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
|
||||
File diff suppressed because it is too large
Load Diff
1326
vllm_npu/attention/mla_v1.py
Normal file
1326
vllm_npu/attention/mla_v1.py
Normal file
File diff suppressed because it is too large
Load Diff
988
vllm_npu/attention/sfa_v1.py
Normal file
988
vllm_npu/attention/sfa_v1.py
Normal file
@@ -0,0 +1,988 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
||||
TypeVar)
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_npu.attention.mla_v1 import AscendMLAMetadata
|
||||
from vllm_npu.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_npu.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_npu.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_npu.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class AscendSFABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND_SFA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return AscendSFAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls():
|
||||
return AscendSFAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendSFAImpl"]:
|
||||
return AscendSFAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAPrefillMetadata:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
workspace: torch.Tensor
|
||||
chunk_seq_lens: torch.Tensor
|
||||
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: list[int]
|
||||
seq_lens: list[int]
|
||||
|
||||
context_lens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_lens: int
|
||||
sin: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFADecodeMetadata:
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
actual_seq_lengths_q: torch.Tensor
|
||||
sin: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
slot_mapping: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
query_lens: Optional[list[int]] = None
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
attn_mask: torch.Tensor = None
|
||||
# chunked prefill by default if no attn_states passed
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
decode: Optional[AscendSFADecodeMetadata] = None
|
||||
prefill: Optional[AscendSFAPrefillMetadata] = None
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
|
||||
# if self.head_dim is not None and self.head_dim \
|
||||
# not in supported_head_sizes:
|
||||
# raise ValueError(
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
def split_metadata_for_multistream(
|
||||
self,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> list["AscendSFAMetadata"]:
|
||||
"""Split metadata for multi-stream with AscendSFAMetadata"""
|
||||
return model_input_split_v1_mla_attn(
|
||||
ms_split_config=ms_split_config,
|
||||
attn_metadata=self,
|
||||
_metadata_cls=AscendMLAMetadata,
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
|
||||
|
||||
class AscendSFAMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[AscendSFAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendSFAMetadata # type: ignore
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||
self.block_size - 1) // self.block_size
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.decode_threshold = 1
|
||||
if self.speculative_config:
|
||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold += spec_token_num
|
||||
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||
got {self.decode_threshold}"
|
||||
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(8 * self.model_config.max_model_len,
|
||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||
# For long-context models try not to over-allocate limiting
|
||||
# kv-cache space, limiting it to 64k tokens,
|
||||
# which would result in the workspace being:
|
||||
# 2*(576)*(64*1024) = 144mb
|
||||
# (assuming 576 MLA head dim, and fp16)
|
||||
# which would result in up-projected context being
|
||||
# 2*(192*128)*(64*1024) = 3gb
|
||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * self.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.cos_cache = None
|
||||
self.sin_cache = None
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
if num_tokens <= self.decode_threshold:
|
||||
decodes.append(i)
|
||||
else:
|
||||
prefills.append(i)
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
first_prefill = 0
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
if decodes[num_decodes - i] >= num_decodes:
|
||||
input_batch.swap_states(prefills[first_prefill],
|
||||
decodes[num_decodes - i])
|
||||
first_prefill += 1
|
||||
modified_batch = True
|
||||
else:
|
||||
break
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
return modified_batch
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendSFAMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens].to(
|
||||
device,
|
||||
non_blocking=True)
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
if self.cos_cache is None:
|
||||
self.cos_cache = model.model.layers[
|
||||
model.model.start_layer].self_attn.rotary_emb.cos_cached
|
||||
self.sin_cache = model.model.layers[
|
||||
model.model.start_layer].self_attn.rotary_emb.sin_cached
|
||||
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
|
||||
self.cos_cache = self.cos_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
query_lens = query_seq_lens_cpu[:num_reqs]
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||
|
||||
prefill_metadata = None
|
||||
chunked_context_metadata = None
|
||||
if num_prefills > 0:
|
||||
reqs_start = num_decodes # prefill_start
|
||||
tokens_start = num_decode_tokens
|
||||
max_query_len = query_lens[reqs_start:].max().item()
|
||||
max_seq_lens = seq_lens[reqs_start:].max().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||
num_prefills_with_context_cpu)
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.block_size)
|
||||
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
chunked_context_metadata = \
|
||||
AscendSFAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos = self.cos_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
actual_query_lens = torch.tensor(query_lens[reqs_start:],
|
||||
dtype=torch.int32).npu()
|
||||
query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
|
||||
dim=0).to(torch.int32)
|
||||
seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu()
|
||||
prefill_metadata = AscendSFAPrefillMetadata(
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
query_lens=query_lens_prefill_sfa,
|
||||
seq_lens=seq_lens_prefill_sfa,
|
||||
context_lens=seq_lens[reqs_start:],
|
||||
input_positions=prefill_input_positions,
|
||||
block_table=block_table[reqs_start:, ...],
|
||||
max_query_len=max_query_len,
|
||||
max_seq_lens=max_seq_lens,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
chunked_context=chunked_context_metadata,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
|
||||
torch.int32).npu()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decodes].to(torch.int32).npu()
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
|
||||
decode_metadata = AscendSFADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
query_lens=query_lens.tolist(),
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
|
||||
class PrefillSFAPreprocessResult(NamedTuple):
|
||||
q_nope: Optional[torch.Tensor] = None
|
||||
q_pe: Optional[torch.Tensor] = None
|
||||
k_nope: Optional[torch.Tensor] = None
|
||||
k_pe: Optional[torch.Tensor] = None
|
||||
topk_indices: Optional[torch.Tensor] = None
|
||||
query_states: Optional[torch.Tensor] = None
|
||||
key_states: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class DecodeSFAPreprocessResult(NamedTuple):
|
||||
q_nope: Optional[torch.Tensor] = None
|
||||
q_pe: Optional[torch.Tensor] = None
|
||||
# nope_cache: Optional[torch.Tensor] = None
|
||||
# rope_cache: Optional[torch.Tensor] = None
|
||||
topk_indices: Optional[torch.Tensor] = None
|
||||
query_states: Optional[torch.Tensor] = None
|
||||
key_states: Optional[torch.Tensor] = None
|
||||
bsz: Optional[int] = None
|
||||
|
||||
|
||||
class AscendSFAImpl(MLAAttentionImpl):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# MLA Args
|
||||
self.q_lora_rank = kwargs['q_lora_rank']
|
||||
self.kv_lora_rank = kwargs['kv_lora_rank']
|
||||
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
|
||||
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
|
||||
self.qk_head_dim = kwargs['qk_head_dim']
|
||||
self.v_head_dim = kwargs['v_head_dim']
|
||||
self.rotary_emb = kwargs['rotary_emb']
|
||||
self.q_proj = kwargs['q_proj']
|
||||
self.kv_b_proj = kwargs['kv_b_proj']
|
||||
self.o_proj = kwargs['o_proj']
|
||||
self.indexer = kwargs['indexer']
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
self.q_a_proj = kwargs.get('q_a_proj', None)
|
||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_rank = self.num_heads // self.tp_size
|
||||
if self.q_a_proj is not None:
|
||||
self.q_b_proj = self.q_proj
|
||||
else:
|
||||
self.q_b_proj = None
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
self.prefill_mask = None
|
||||
|
||||
# indexer param
|
||||
self.dim = self.indexer.dim
|
||||
self.n_heads: int = self.indexer.n_heads # 64
|
||||
self.head_dim: int = self.indexer.head_dim # 128
|
||||
self.index_topk: int = self.indexer.index_topk # 2048
|
||||
self.wq_b = self.indexer.wq_b
|
||||
self.wk = self.indexer.wk
|
||||
self.weights_proj = self.indexer.weights_proj
|
||||
self.k_norm = self.indexer.k_norm
|
||||
self.softmax_scale = self.indexer.softmax_scale
|
||||
|
||||
# Adapt torch air graph mode with spec decoding.
|
||||
speculative_config = vllm_config.speculative_config
|
||||
if speculative_config is not None:
|
||||
self.spec_token_num = speculative_config.num_speculative_tokens
|
||||
assert self.spec_token_num > 0
|
||||
|
||||
self.cp_size = 1
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous()
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous()
|
||||
|
||||
# Waiting for BMM NZ support
|
||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv):
|
||||
# SFA Preprocess:
|
||||
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
|
||||
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
|
||||
# 3. If need_gather_q_kv, perform all_gather.
|
||||
# 4. Preprocess decode tokens, write kv cache and get:
|
||||
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
|
||||
# 5. Preprocess prefill tokens, write kv cache and get:
|
||||
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
if need_gather_q_kv:
|
||||
# q_c = get_tp_group().all_gather(q_c, 0)
|
||||
# kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
||||
hidden_states = get_tp_group().all_gather(hidden_states, 0)
|
||||
# hidden_states_decode = hidden_states[:num_decode_tokens]
|
||||
# if self.q_a_proj is not None:
|
||||
# npu_prefetch(self.q_a_proj.weight,
|
||||
# hidden_states,
|
||||
# enabled=self.enable_prefetch)
|
||||
# ckq = self.q_a_proj(hidden_states) # q down
|
||||
# q_c = self.q_a_layernorm(ckq) # q down layernorm
|
||||
# else:
|
||||
# q_c = hidden_states
|
||||
|
||||
# kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv
|
||||
# Process for shared_expert_dp
|
||||
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
# Preprocess for decode tokens
|
||||
if has_decode:
|
||||
q_len = 1
|
||||
hidden_states_decode = hidden_states[:num_decode_tokens]
|
||||
decode_kq = self.q_a_proj(hidden_states_decode) # q down
|
||||
decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm
|
||||
decode_kv_no_split = self.kv_a_proj_with_mqa(
|
||||
hidden_states_decode) # c_kv
|
||||
|
||||
# decode_q_c = q_c[:num_decode_tokens]
|
||||
decode_slot_mapping = attn_metadata.slot_mapping[:
|
||||
num_decode_tokens]
|
||||
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]
|
||||
|
||||
decode_q = self.q_b_proj(decode_q_c)
|
||||
bsz, _ = decode_q.shape
|
||||
decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim)
|
||||
decode_q_nope, decode_q_pe = torch.split(
|
||||
decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
decode_q_nope = decode_q_nope.view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
|
||||
decode_q_nope = (torch.matmul(decode_q_nope,
|
||||
self.kv_b_proj_w_k).transpose(
|
||||
1,
|
||||
0).view(bsz, q_len,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank))
|
||||
|
||||
# stream2 kv
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
cos = attn_metadata.decode.cos
|
||||
sin = attn_metadata.decode.sin
|
||||
cos_q, sin_q = cos, sin
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1)
|
||||
decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
decode_kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
decode_slot_mapping.to(torch.int64),
|
||||
value_cache,
|
||||
key_cache,
|
||||
c_kv_scale=None,
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode='PA') # adapter NZ
|
||||
# nz_block_size = 16
|
||||
# KVCACHE_NZ_DIM = 16
|
||||
# decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size)
|
||||
# decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM)
|
||||
|
||||
decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos,
|
||||
sin) # BNSD
|
||||
|
||||
decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
|
||||
self.kv_lora_rank)
|
||||
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||
|
||||
topk_indices = self.indexer_select(hidden_states_decode,
|
||||
decode_q_c,
|
||||
attn_metadata=attn_metadata,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
kv_cache=kv_cache)
|
||||
|
||||
query_states = (decode_q_nope, decode_q_pe)
|
||||
key_states = (decode_k_nope, decode_k_rope)
|
||||
decode_preprocess_res = DecodeSFAPreprocessResult(
|
||||
q_nope=decode_q_nope,
|
||||
q_pe=decode_q_pe,
|
||||
# nope_cache = nope_cache,
|
||||
# rope_cache = rope_cache,
|
||||
topk_indices=topk_indices,
|
||||
query_states=query_states,
|
||||
key_states=key_states,
|
||||
bsz=bsz,
|
||||
)
|
||||
|
||||
# Preprocess for prefill tokens
|
||||
if has_prefill:
|
||||
bsz = 1
|
||||
|
||||
hidden_states_prefill = hidden_states[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
prefill_kq = self.q_a_proj(hidden_states_prefill) # q down
|
||||
prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm
|
||||
prefill_kv_no_split = self.kv_a_proj_with_mqa(
|
||||
hidden_states_prefill) # c_kv
|
||||
|
||||
# prefill_q_c = q_c[
|
||||
# num_decode_tokens:num_actual_tokens]
|
||||
prefill_slot_mapping = attn_metadata.slot_mapping[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]
|
||||
|
||||
prefill_slot_mapping = attn_metadata.slot_mapping[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
# prefill_kv_no_split = kv_no_split[
|
||||
# num_decode_tokens:num_actual_tokens]
|
||||
# prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens]
|
||||
prefill_qr = prefill_q_c
|
||||
prefill_q = self.q_b_proj(prefill_qr)
|
||||
prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_nope, prefill_q_pe = torch.split(
|
||||
prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
prefill_q_nope = prefill_q_nope.view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
|
||||
prefill_q_nope = (torch.matmul(prefill_q_nope,
|
||||
self.kv_b_proj_w_k).transpose(
|
||||
1,
|
||||
0).view(-1, self.num_heads,
|
||||
self.kv_lora_rank))
|
||||
prefill_q_pe = prefill_q_pe.unsqueeze(2)
|
||||
|
||||
# stream2 kv
|
||||
|
||||
nope_cache = kv_cache[0]
|
||||
rope_cache = kv_cache[1]
|
||||
cos = attn_metadata.prefill.cos
|
||||
sin = attn_metadata.prefill.sin
|
||||
cos_q, sin_q = cos, sin
|
||||
|
||||
# cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
# sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
prefill_q_pe = torch_npu.npu_interleave_rope(
|
||||
prefill_q_pe, cos_q, sin_q) # BNSD
|
||||
prefill_q_pe = prefill_q_pe.squeeze(2) #BSH
|
||||
# q[..., self.qk_nope_head_dim:] = prefill_q_pe  # TODO: confirm whether this assignment is still needed
|
||||
|
||||
prefill_latent_cache = prefill_kv_no_split # (B,S,N,D)
|
||||
prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
prefill_latent_cache.view(
|
||||
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
|
||||
self.kv_a_layernorm.weight,
|
||||
cos.view(-1, 1, 1, self.qk_rope_head_dim),
|
||||
sin.view(-1, 1, 1, self.qk_rope_head_dim),
|
||||
prefill_slot_mapping.to(torch.int64),
|
||||
rope_cache,
|
||||
nope_cache,
|
||||
k_rope_scale=None,
|
||||
c_kv_scale=None,
|
||||
k_rope_offset=None,
|
||||
c_kv_offset=None,
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode="PA")
|
||||
|
||||
topk_indices = self.indexer_select(x=hidden_states_prefill,
|
||||
qr=prefill_qr,
|
||||
kv_cache=kv_cache,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
attn_metadata=attn_metadata)
|
||||
query_states = (prefill_q_nope, prefill_q_pe)
|
||||
key_states = (prefill_k_nope, prefill_k_pe)
|
||||
prefill_preprocess_res = PrefillSFAPreprocessResult(
|
||||
q_nope=prefill_q_nope,
|
||||
q_pe=prefill_q_pe,
|
||||
topk_indices=topk_indices,
|
||||
k_nope=prefill_k_nope,
|
||||
k_pe=prefill_k_pe,
|
||||
query_states=query_states,
|
||||
key_states=key_states,
|
||||
)
|
||||
|
||||
return decode_preprocess_res, prefill_preprocess_res
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # query in unified attn
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attn_metadata: M,
|
||||
need_gather_q_kv: bool = False,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Inputs and outputs may be padded for graph capture (ACL graphs on NPU)
|
||||
output = output[:num_actual_tokens, ...]
|
||||
o_proj_input_shape = (num_actual_tokens,
|
||||
self.num_heads * self.v_head_dim)
|
||||
o_proj_input = torch.empty(o_proj_input_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
# SFA Preprocess
|
||||
decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
|
||||
|
||||
if decode_preprocess_res is not None:
|
||||
# bsz, q_len, _, _ = query_states[0].shape
|
||||
decode_attn_output = self.apply_attention_fusion(
|
||||
query_states=decode_preprocess_res.query_states,
|
||||
key_states=decode_preprocess_res.key_states,
|
||||
attn_metadata=attn_metadata,
|
||||
topk_indices=decode_preprocess_res.topk_indices)
|
||||
o_proj_input[:num_decode_tokens] = decode_attn_output
|
||||
|
||||
if prefill_preprocess_res is not None:
|
||||
prefill_attn_output = self.apply_attention_fusion(
|
||||
query_states=prefill_preprocess_res.query_states,
|
||||
key_states=prefill_preprocess_res.key_states,
|
||||
attn_metadata=attn_metadata,
|
||||
topk_indices=prefill_preprocess_res.topk_indices)
|
||||
o_proj_input[num_decode_tokens:] = prefill_attn_output
|
||||
|
||||
output[...] = self.mla_epilog(o_proj_input, absorb=True)
|
||||
return output
|
||||
|
||||
def apply_attention_fusion(self, query_states, key_states, topk_indices,
|
||||
attn_metadata: M):
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
q_nope, q_pe = query_states
|
||||
k_nope, k_rope = key_states
|
||||
|
||||
if attn_metadata.prefill is not None:
|
||||
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
|
||||
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
|
||||
query=q_nope,
|
||||
key=k_nope,
|
||||
value=k_nope,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=prefill_metadata.block_table,
|
||||
actual_seq_lengths_query=prefill_metadata.query_lens,
|
||||
actual_seq_lengths_kv=prefill_metadata.seq_lens,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
)
|
||||
|
||||
elif attn_metadata.decode is not None:
|
||||
decode_metadata = attn_metadata.decode
|
||||
|
||||
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
|
||||
query=q_nope,
|
||||
key=k_nope,
|
||||
value=k_nope,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=decode_metadata.seq_lens,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
)
|
||||
slc_fa_fusion = slc_fa_fusion.squeeze(1)
|
||||
|
||||
slc_fa_fusion = slc_fa_fusion.transpose(0, 1)
|
||||
|
||||
# input shape [N//attn_tp_size, T(bs*q_len), D]
|
||||
# output shape [T(bs*q_len), N//attn_tp_size, D]
|
||||
attn_output = torch.matmul(slc_fa_fusion,
|
||||
self.kv_b_proj_w_v).transpose(1, 0).reshape(
|
||||
-1, self.num_heads * self.v_head_dim)
|
||||
# Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and
|
||||
# with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1.
|
||||
# after reshape: [T(bs*q_len), 1, N//attn_tp_size*D]
|
||||
# attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
return attn_output
|
||||
|
||||
def mla_epilog(self,
|
||||
attn_output: torch.Tensor = None,
|
||||
absorb: bool = False):
|
||||
# TODO: need to check
|
||||
attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0],
|
||||
-1),
|
||||
is_prefill=True,
|
||||
is_force_scatter=False)
|
||||
|
||||
return attn_output
|
||||
|
||||
def indexer_select(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
cos,
|
||||
sin,
|
||||
attn_metadata: M,
|
||||
):
|
||||
if attn_metadata.prefill is not None:
|
||||
actual_seq_lengths_query = attn_metadata.prefill.query_lens
|
||||
actual_seq_lengths_key = attn_metadata.prefill.seq_lens
|
||||
block_table = attn_metadata.prefill.block_table
|
||||
elif attn_metadata.decode is not None:
|
||||
actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
|
||||
actual_seq_lengths_key = attn_metadata.decode.seq_lens
|
||||
block_table = attn_metadata.decode.block_table
|
||||
|
||||
cos_q, sin_q = cos, sin
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
|
||||
# q process in new stream
|
||||
q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||
q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128]
|
||||
q_pe, q_nope = torch.split(
|
||||
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64,64+64]
|
||||
|
||||
q_pe = q_pe.unsqueeze(2)
|
||||
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
|
||||
q_pe = q_pe.squeeze(2)
|
||||
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
|
||||
|
||||
k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
|
||||
k = self.k_norm(k_proj).unsqueeze(1)
|
||||
k_pe, k_nope = torch.split(
|
||||
k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||
dim=-1) # [b,s,64+64]
|
||||
|
||||
k_pe = k_pe.unsqueeze(2)
|
||||
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
|
||||
k_pe = k_pe.squeeze(2)
|
||||
|
||||
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
|
||||
|
||||
if kv_cache is not None:
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
|
||||
attn_metadata.slot_mapping.view(
|
||||
-1, 1),
|
||||
k.view(-1,
|
||||
k.shape[-1])) # b, s, n, d
|
||||
|
||||
weights = self.weights_proj(x)
|
||||
|
||||
topk_indices = torch.ops.custom.npu_lightning_indexer(
|
||||
query=q,
|
||||
key=kv_cache[2],
|
||||
weights=weights,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
block_table=block_table,
|
||||
layout_query="TND",
|
||||
layout_key="PA_BSND",
|
||||
sparse_count=2048,
|
||||
sparse_mode=3)
|
||||
return topk_indices
|
||||
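# Illustrative sketch (not part of the source): the preprocess path above
# splits the flattened token batch into a decode segment followed by a
# prefill segment, using num_decode_tokens / num_actual_tokens from the
# attention metadata. The helper below is a hypothetical, standalone
# restatement of that slicing for clarity.
def _split_token_batch_sketch(hidden_states: torch.Tensor,
                              slot_mapping: torch.Tensor,
                              num_decode_tokens: int,
                              num_actual_tokens: int):
    decode_hs = hidden_states[:num_decode_tokens]
    decode_slots = slot_mapping[:num_decode_tokens]
    prefill_hs = hidden_states[num_decode_tokens:num_actual_tokens]
    prefill_slots = slot_mapping[num_decode_tokens:num_actual_tokens]
    return (decode_hs, decode_slots), (prefill_hs, prefill_slots)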
180
vllm_npu/attention/utils.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendCommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens_cpu: torch.Tensor
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""same to seq_lens_cpu, for compatibility with some new attn metadata
|
||||
(such as GDN)."""
|
||||
|
||||
num_computed_tokens_cpu: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
|
||||
max_query_len: int
|
||||
"""Max token number of request in batch"""
|
||||
|
||||
decode_token_per_req: int
|
||||
"""decode token number per request"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
|
||||
positions: torch.Tensor = None
|
||||
|
||||
attn_mask: torch.Tensor = None
|
||||
|
||||
spec_attn_mask: torch.Tensor = None
|
||||
|
||||
attn_state: Any = None
|
||||
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
is_only_prefill: bool = False
|
||||
|
||||
graph_pad_size: int = -1
|
||||
|
||||
# num_input_tokens refers to total number of tokens including
|
||||
# padding tokens. It is used to handle some padding operations.
|
||||
num_input_tokens: int = 0
|
||||
|
||||
# NOTE: This is a temporary solution for rotary embedding in MLA
|
||||
cos: torch.Tensor = None
|
||||
sin: torch.Tensor = None
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Assuming a reordered batch, finds the boundary between prefill and decode
|
||||
requests.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: AscendCommonAttentionMetadata object containing the
|
||||
batch metadata.
|
||||
decode_threshold: The maximum query length to be considered a decode.
|
||||
|
||||
Returns:
|
||||
num_decodes: The number of decode requests.
|
||||
num_prefills: The number of prefill requests.
|
||||
num_decode_tokens: The number of tokens in the decode requests.
|
||||
num_prefill_tokens: The number of tokens in the prefill requests.
|
||||
"""
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
if max_query_len <= decode_threshold:
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
is_prefill = query_lens > decode_threshold
|
||||
if not torch.any(is_prefill):
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||
num_decodes = first_prefill
|
||||
num_prefills = num_reqs - num_decodes
|
||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
|
||||
|
||||
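# Illustrative sketch (not part of the source): the same boundary computation
# as split_decodes_and_prefills, restated on a plain query_start_loc list.
# Requests with query length <= decode_threshold are assumed to be ordered
# before prefill requests, as in the reordered batch above.
def _split_decodes_and_prefills_demo(query_start_loc: List[int],
                                     decode_threshold: int = 1):
    query_lens = [b - a for a, b in zip(query_start_loc[:-1],
                                        query_start_loc[1:])]
    num_reqs = len(query_lens)
    num_tokens = query_start_loc[-1]
    first_prefill = next(
        (i for i, qlen in enumerate(query_lens) if qlen > decode_threshold),
        num_reqs)
    num_decode_tokens = query_start_loc[first_prefill]
    return (first_prefill, num_reqs - first_prefill, num_decode_tokens,
            num_tokens - num_decode_tokens)

# Example: _split_decodes_and_prefills_demo([0, 1, 2, 3, 11, 27]) -> (3, 2, 3, 24)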
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
# TODO: assert ascendMetadata
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor],
|
||||
):
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
|
||||
connector = get_kv_transfer_group()
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
return
|
||||
# TODO: assert ascendMetadata
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
||||
|
||||
|
||||
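# Illustrative sketch (not part of the source): an attention layer could
# bracket its KV-cache update with the two connector helpers above. The
# layer_name / kv_cache_layer / compute_fn arguments are hypothetical.
def _kv_connector_bracket_sketch(layer_name: str,
                                 kv_cache_layer: List[torch.Tensor],
                                 compute_fn) -> None:
    # Block until the connector has loaded any remote KV for this layer.
    wait_for_kv_layer_from_connector(layer_name)
    # Run the layer's attention computation (writes into kv_cache_layer).
    compute_fn()
    # Hand the freshly written KV blocks back to the connector, if any.
    maybe_save_kv_layer_to_connector(layer_name, kv_cache_layer)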
def round_up(val: int, align: int) -> int:
|
||||
if align == 0:
|
||||
return 0
|
||||
return -(val // -align) * align
|
||||
|
||||
|
||||
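# Illustrative examples (not part of the source): round_up rounds a value up
# to the next multiple of align; transdata below uses it to pad ND matrices
# to whole fractal blocks.
def _round_up_examples() -> None:
    assert round_up(17, 16) == 32
    assert round_up(16, 16) == 16
    assert round_up(5, 0) == 0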
def trans_rope_weight(weight, rope_dim):
|
||||
if rope_dim == 0:
|
||||
return weight.contiguous()
|
||||
nope_part = weight[..., :-rope_dim, :]
|
||||
rope_part = weight[..., -rope_dim:, :]
|
||||
reordered_rope_part = torch.cat(
|
||||
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
|
||||
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
|
||||
|
||||
|
||||
def transdata(nd_mat, block_size: tuple = (16, 16)):
|
||||
r = round_up(nd_mat.shape[0], block_size[0])
|
||||
c = round_up(nd_mat.shape[1], block_size[1])
|
||||
r_pad = r - nd_mat.shape[0]
|
||||
c_pad = c - nd_mat.shape[1]
|
||||
# F.pad pads the last dimension first, so the column pad comes before the row pad.
nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
|
||||
nz_mat = torch.permute(
|
||||
torch.reshape(
|
||||
nd_mat,
|
||||
(r // block_size[0], block_size[0], c // block_size[1],
|
||||
block_size[1]),
|
||||
),
|
||||
[2, 0, 1, 3],
|
||||
)
|
||||
nz_mat = torch.reshape(
|
||||
nz_mat,
|
||||
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
|
||||
return nz_mat
|
||||
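# Illustrative sketch (not part of the source): shape-level behaviour of
# transdata. A (32, 48) ND matrix with the default (16, 16) block size is
# reorganised into the NZ (fractal) layout of shape (48 // 16, 32, 16).
def _transdata_shape_example() -> torch.Tensor:
    nd = torch.arange(32 * 48, dtype=torch.float32).reshape(32, 48)
    nz = transdata(nd)
    assert nz.shape == (3, 32, 16)
    return nz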
0
vllm_npu/compilation/__init__.py
Normal file
398
vllm_npu/compilation/acl_graph.py
Normal file
@@ -0,0 +1,398 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import weak_ref_tensors
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ACLGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
aclgraph: Optional[torch.npu.NPUGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for aclgraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
class ACLGraphWrapper:
|
||||
"""Wraps a runnable to add acl graph capturing and replaying ability. And
|
||||
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||
|
||||
The workflow of this wrapper in the aclgraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
batch_descriptor(key) from the forward context and blindly trust them
|
||||
for aclgraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
the wrapper will perform aclgraph capture(if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
Note: ACLGraphWrapper does not store persistent buffers or copy any
|
||||
runtime inputs into those buffers for replay. We assume this is implemented
outside of the wrapper, because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
graph_pool: Any = None,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.graph_pool = graph_pool
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.first_run_finished = False
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# assert runtime_mode is not NONE(no aclgraph), otherwise, we don't
|
||||
# need to initialize a ACLGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
if self.graph_pool is None:
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
self.aclgraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# aclgraphs for.
|
||||
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
|
||||
= {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"aclgraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
aclgraph_runtime_mode != self.runtime_mode:
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without aclgraphs.
|
||||
# We do not trigger capture/replay if the runtime mode does not
# match. This enables properly dispatching to the correct
|
||||
# CUDAGraphWrapper when nesting multiple instances with different
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
if batch_descriptor not in self.concrete_aclgraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_aclgraph_entries[batch_descriptor] = \
|
||||
ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||
|
||||
entry = self.concrete_aclgraph_entries[batch_descriptor]
|
||||
|
||||
if entry.aclgraph is None:
|
||||
if self.aclgraph_options.debug_log_enable:
|
||||
# Since we capture aclgraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug("Capturing a aclgraph on (%s,%s)",
|
||||
self.runtime_mode.name, entry.batch_descriptor)
|
||||
# validate that aclgraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
aclgraph = torch.npu.NPUGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if self.aclgraph_options.gc_disable:
|
||||
# during every model forward for piecewise aclgraph
|
||||
# mode, we will capture many pieces of aclgraphs
|
||||
# (roughly one per layer). running gc again and again
|
||||
# across layers will make the aclgraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.npu.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
forward_context.capturing = True
|
||||
with torch.npu.graph(aclgraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's aclgraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
if self.aclgraph_options.weak_ref_output:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph in piecewise aclgraph mode, because
|
||||
# the output of the last graph will not be used by
|
||||
# any other acl graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.aclgraph = aclgraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during acl graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for aclgraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}")
|
||||
|
||||
logger.info_once("Replaying aclgraph")
|
||||
entry.aclgraph.replay()
|
||||
return entry.output
|
||||
|
||||
|
||||
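# Illustrative sketch (not part of the source): wrapping a runnable for
# FULL-mode graph capture. `runnable` and `vllm_config` are placeholders;
# at runtime the forward context supplies the batch descriptor, so the first
# call for a given descriptor captures the graph and later calls replay it.
def _wrap_for_full_graph_sketch(runnable: Callable,
                                vllm_config: VllmConfig) -> ACLGraphWrapper:
    return ACLGraphWrapper(runnable, vllm_config,
                           runtime_mode=CUDAGraphMode.FULL)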
def update_attn_params(update_stream,
|
||||
forward_context,
|
||||
runtime_shape,
|
||||
kv_transfer_config=None):
|
||||
graph_params = get_graph_params()
|
||||
|
||||
# NOTE(Angazenn): By moving the npu-stream context ahead,
|
||||
# (see https://github.com/vllm-project/vllm-ascend/pull/3985)
|
||||
# we can reduce host overhead introduced by stream initialization.
|
||||
# However, we find that this might cause accuracy problems
# with pd-disaggregation. Therefore, this optimization is only enabled
# without pd-disaggregation. We are working on solving this problem
# directly in the future.
|
||||
if kv_transfer_config is not None:
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
|
||||
# There are some rare bugs in FULL_DECODE_ONLY mode with GQA, triggered by
# getting the workspace for _npu_paged_attention in torch_npu. In some
# cases, _npu_paged_attention requires a different workspace for different
# seq_lens, so an additional get_workspace call is added here
|
||||
# to avoid such bugs.
|
||||
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
||||
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
else:
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
speculative_config):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||
spec_attn_mask, sparse_mode, scale, block_table, block_size,
|
||||
seq_lens_list, actual_seq_lengths, attn_output,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[
|
||||
key].decode.seq_lens_list
|
||||
if speculative_config and speculative_config.method == "deepseek_mtp":
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
spec_multiple = speculative_config.num_speculative_tokens + 1
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
runtime_shape // spec_multiple - len(seq_lens_list))
|
||||
actual_seq_lengths = [
|
||||
spec_multiple * (i + 1)
|
||||
for i in range(runtime_shape // spec_multiple)
|
||||
]
|
||||
else:
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
input_layout=input_layout,
|
||||
atten_mask=spec_attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse])
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
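# Illustrative sketch (not part of the source): the padded per-request query
# lengths constructed above for deepseek_mtp speculative decoding, where
# spec_multiple = num_speculative_tokens + 1.
def _mtp_actual_seq_lengths_sketch(runtime_shape: int,
                                   num_speculative_tokens: int) -> list[int]:
    spec_multiple = num_speculative_tokens + 1
    return [spec_multiple * (i + 1)
            for i in range(runtime_shape // spec_multiple)]

# Example: _mtp_actual_seq_lengths_sketch(8, 1) -> [2, 4, 6, 8]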
@dataclass
|
||||
class GraphParams:
|
||||
events: dict[int, list[torch.npu.ExternalEvent]]
|
||||
workspaces: dict[int, torch.Tensor]
|
||||
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
|
||||
attn_params: dict[int, list[tuple]]
|
||||
|
||||
|
||||
_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
raise ValueError("Graph parameters have already been set!")
|
||||
_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: Any):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
_graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
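# Illustrative sketch (not part of the source): typical lifecycle of the
# module-level graph parameters. The capture sizes are hypothetical.
def _graph_params_lifecycle_sketch() -> None:
    if get_graph_params() is None:
        set_graph_params({1, 2, 4, 8})
    params = get_graph_params()
    # One attn_params / handles / events slot exists per capture size.
    assert params is not None and 4 in params.attn_params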
0
vllm_npu/core/__init__.py
Normal file
39
vllm_npu/core/recompute_schedule_config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Type, Union
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
|
||||
MAX_INT = 2147483647
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecomputeSchedulerConfig(SchedulerConfig):
|
||||
scheduler_cls: Union[str, Type[object]] = (
|
||||
"vllm_npu.core.recompute_scheduler.RecomputeScheduler")
|
||||
|
||||
@classmethod
|
||||
def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig):
|
||||
scheduler_config = {
|
||||
field.name: getattr(vllm_scheduler_config, field.name)
|
||||
for field in fields(vllm_scheduler_config) if field.init
|
||||
}
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_npu.core.recompute_scheduler.RecomputeScheduler")
|
||||
return cls(**scheduler_config)
|
||||
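# Illustrative usage (not part of the source): deriving the recompute
# scheduler configuration from an existing vLLM SchedulerConfig.
def _make_recompute_config_sketch(
        base: SchedulerConfig) -> RecomputeSchedulerConfig:
    cfg = RecomputeSchedulerConfig.initialize_from_config(base)
    assert cfg.scheduler_cls.endswith("RecomputeScheduler")
    return cfg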
1392
vllm_npu/core/recompute_scheduler.py
Normal file
File diff suppressed because it is too large
108
vllm_npu/core/schedule_config.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Type, Union
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
|
||||
MAX_INT = 2147483647
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSchedulerConfig(SchedulerConfig):
|
||||
enable_chunked_prefill: bool = False
|
||||
max_long_partial_prefills: int = 1
|
||||
long_prefill_token_threshold: int = MAX_INT
|
||||
policy: str = "fcfs"
|
||||
scheduler_cls: Union[str, Type[object]] = (
|
||||
"vllm_npu.core.scheduler.AscendScheduler")
|
||||
enable_pd_transfer: bool = False
|
||||
decode_max_num_seqs: int = 0
|
||||
|
||||
@classmethod
|
||||
def initialize_from_config(
|
||||
cls,
|
||||
vllm_scheduler_config: SchedulerConfig,
|
||||
ascend_scheduler_config,
|
||||
):
|
||||
scheduler_config = {
|
||||
field.name: getattr(vllm_scheduler_config, field.name)
|
||||
for field in fields(vllm_scheduler_config) if field.init
|
||||
}
|
||||
# Override default values into original SchedulerConfig
|
||||
scheduler_config["enable_chunked_prefill"] = False
|
||||
scheduler_config["max_long_partial_prefills"] = None
|
||||
scheduler_config["long_prefill_token_threshold"] = None
|
||||
scheduler_config["policy"] = "fcfs"
|
||||
scheduler_config["scheduler_cls"] = (
|
||||
"vllm_npu.core.scheduler.AscendScheduler")
|
||||
scheduler_config["enable_pd_transfer"] = False
|
||||
scheduler_config["decode_max_num_seqs"] = 0
|
||||
# Override params in original SchedulerConfig with params in ascend_scheduler_config
|
||||
for k, _ in scheduler_config.items():
|
||||
if hasattr(ascend_scheduler_config, k):
|
||||
scheduler_config[k] = getattr(ascend_scheduler_config, k)
|
||||
return cls(**scheduler_config)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
|
||||
self.encoder_cache_size = self.max_num_batched_tokens
|
||||
self.chunked_prefill_enabled = self.enable_chunked_prefill
|
||||
if (self.max_num_batched_tokens < self.max_model_len
|
||||
and not self.chunked_prefill_enabled):
|
||||
raise ValueError(
|
||||
"Ascend scheduler is enabled without chunked prefill feature. "
|
||||
f"Argument max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||
f"smaller than max_model_len ({self.max_model_len}). "
|
||||
"This effectively limits the maximum sequence length to "
|
||||
"max_num_batched_tokens and makes vLLM reject longer "
|
||||
"sequences. Please increase max_num_batched_tokens or "
|
||||
"decrease max_model_len.")
|
||||
# concurrent partial prefills. Default is 1 meaning not enabled.
|
||||
if self.max_long_partial_prefills is None:
|
||||
self.max_long_partial_prefills = 1
|
||||
self.long_prefill_token_threshold = MAX_INT
|
||||
|
||||
if self.long_prefill_token_threshold is None or \
|
||||
self.long_prefill_token_threshold <= 0:
|
||||
if self.max_model_len is None:
|
||||
self.long_prefill_token_threshold = MAX_INT
|
||||
else:
|
||||
self.long_prefill_token_threshold = \
|
||||
max(1, int(self.max_model_len * 0.04))
|
||||
|
||||
if self.max_long_partial_prefills < 0:
|
||||
raise ValueError(
|
||||
f"max_long_partial_prefills must be non-negative, but got "
|
||||
f"{self.max_long_partial_prefills}")
|
||||
if self.long_prefill_token_threshold < 0:
|
||||
raise ValueError(
|
||||
f"long_prefill_token_threshold must be non-negative, but got "
|
||||
f"{self.long_prefill_token_threshold}")
|
||||
|
||||
if self.policy != "fcfs":
|
||||
raise NotImplementedError(
|
||||
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
|
||||
)
|
||||
if self.send_delta_data:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support send_delta_data.")
|
||||
if getattr(self, "scheduler_delay_factor", 0) > 0:
|
||||
raise NotImplementedError(
|
||||
"currently AscendScheduler doesn't support scheduler_delay_factor."
|
||||
)
|
||||
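# Illustrative sketch (not part of the source): the default long-prefill
# threshold chosen in __post_init__ when no explicit value is given.
def _default_long_prefill_threshold_sketch(max_model_len) -> int:
    if max_model_len is None:
        return MAX_INT
    return max(1, int(max_model_len * 0.04))

# Example: _default_long_prefill_threshold_sketch(8192) -> 327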
587
vllm_npu/core/scheduler.py
Normal file
@@ -0,0 +1,587 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Iterable, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVEventBatch
|
||||
from vllm.logger import logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
|
||||
class AscendScheduler(Scheduler):
|
||||
"""This Scheduler extends vllm's original v1 scheduler
|
||||
with a prefill-first scheduling strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vllm_config, kv_cache_config,
|
||||
structured_output_manager, mm_registry,
|
||||
include_finished_set, log_stats)
|
||||
self.scheduled_req_ids: set[str] = set()
|
||||
self.running: list[Request] = []
|
||||
|
||||
self.finished_prefill_reqs: deque[Request] = deque()
|
||||
enable_pd_transfer = getattr(self.scheduler_config,
|
||||
'enable_pd_transfer', False)
|
||||
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||
'decode_max_num_seqs', 0)
|
||||
self.phase = "" if not enable_pd_transfer else "prefill"
|
||||
self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
|
||||
decode_max_num_seqs)
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
return super().schedule()
|
||||
scheduled_new_reqs: list[Request] = []
|
||||
scheduled_resumed_reqs: list[Request] = []
|
||||
scheduled_running_reqs: list[Request] = []
|
||||
preempted_reqs: list[Request] = []
|
||||
|
||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
|
||||
# Encoder-related.
|
||||
scheduled_encoder_inputs: dict[str, list[int]] = {}
|
||||
encoder_budget = self.max_num_encoder_input_tokens
|
||||
|
||||
# Spec decode-related.
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
|
||||
|
||||
# For logging.
|
||||
scheduled_timestamp = time.monotonic()
|
||||
|
||||
# Record scheduled LoRA requests.
|
||||
scheduled_loras: set[int] = set()
|
||||
|
||||
# Use a temporary deque to collect requests that need to be skipped
|
||||
# and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests: deque[Request] = deque()
|
||||
|
||||
if self.phase == "prefill":
|
||||
remaining_running_reqs = []
|
||||
for request in self.running:
|
||||
# Move requests that have finished prefill to finished_prefill_reqs
|
||||
if request.num_tokens > request.num_prompt_tokens:
|
||||
self.finished_prefill_reqs.append(request)
|
||||
else:
|
||||
remaining_running_reqs.append(request)
|
||||
self.running = remaining_running_reqs
|
||||
# All requests have been prefilled; switch phase to decode
|
||||
if not self.waiting and not self.running:
|
||||
self.phase = "decode"
|
||||
# Skip long prompt requests in prefill stage.
|
||||
# long_prefill_budget is float('inf') if not in use.
|
||||
if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0:
|
||||
long_prefill_budget = float('inf')
|
||||
long_prefill_token_threshold = float('inf')
|
||||
else:
|
||||
long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
|
||||
long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold
|
||||
|
||||
# Schedule prefill requests first.
|
||||
while self.waiting and token_budget > 0:
|
||||
if len(self.running) == (self.decode_max_num_running_reqs
|
||||
if self.phase == "decode" else
|
||||
self.max_num_running_reqs):
|
||||
|
||||
break
|
||||
|
||||
request = self.waiting[0]
|
||||
|
||||
def skip_cur_request():
|
||||
self.waiting.popleft()
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
|
||||
# P/D: skip request if still waiting for remote kvs.
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
is_ready = self._update_waiting_for_remote_kv(request)
|
||||
if is_ready:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if (self.lora_config and request.lora_request and
|
||||
(len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id not in scheduled_loras)):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
load_kv_async = False
|
||||
|
||||
# Get already-cached tokens.
|
||||
if request.num_computed_tokens == 0:
|
||||
new_computed_blocks, num_new_local_computed_tokens = \
|
||||
self.kv_cache_manager.get_computed_blocks(
|
||||
request)
|
||||
|
||||
# Get externally-cached tokens if using a KVConnector.
|
||||
if self.connector is not None:
|
||||
num_external_computed_tokens, load_kv_async = (
|
||||
self.connector.get_num_new_matched_tokens(
|
||||
request, num_new_local_computed_tokens))
|
||||
|
||||
# Total computed tokens (local + external).
|
||||
num_computed_tokens = (num_new_local_computed_tokens +
|
||||
num_external_computed_tokens)
|
||||
else:
|
||||
# P/D: skip checking prefix cache if loaded from remote kvs.
|
||||
new_computed_blocks = (
|
||||
self.kv_cache_manager.create_empty_block_list())
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
encoder_inputs_to_schedule = None
|
||||
new_encoder_budget = encoder_budget
|
||||
|
||||
# P/D: loading remote KV, do not allocate for new work.
|
||||
if load_kv_async:
|
||||
assert num_external_computed_tokens > 0
|
||||
num_new_tokens = 0
|
||||
blocks = None
|
||||
# Number of tokens to be scheduled.
|
||||
else:
|
||||
prompt_limit = self._get_prompt_limit(request)
|
||||
# We use `request.num_tokens` instead of
|
||||
# `request.num_prompt_tokens` to consider the resumed
|
||||
# requests, which have output tokens.
|
||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||
max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
|
||||
self.block_size)
|
||||
prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
|
||||
|
||||
# Finish request that exceeds prompt_limit or kv cache size.
|
||||
if num_new_tokens > prompt_limit:
|
||||
logger.warning(
|
||||
"Input prompt (%d tokens) is too long"
|
||||
" and exceeds limit of %d",
|
||||
num_new_tokens,
|
||||
prompt_limit,
|
||||
)
|
||||
request.status = RequestStatus.FINISHED_IGNORED
|
||||
self.finished_req_ids.add( # type: ignore
|
||||
request.request_id) # type: ignore
|
||||
self.waiting.popleft()
|
||||
continue
|
||||
|
||||
if num_new_tokens > token_budget:
|
||||
# Scheduling would exceed token_budget, skip.
|
||||
skip_cur_request()
|
||||
continue
|
||||
assert num_new_tokens > 0
|
||||
blocks = new_computed_blocks.blocks[0]
|
||||
|
||||
# Schedule encoder inputs.
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_budget) = self._try_schedule_encoder_inputs(
|
||||
request, num_computed_tokens, num_new_tokens,
|
||||
encoder_budget)
|
||||
if num_new_tokens == 0 or len(
|
||||
encoder_inputs_to_schedule) == 0:
|
||||
# The request cannot be scheduled.
|
||||
break
|
||||
|
||||
watermark = getattr(self.scheduler_config, "watermark", 0.01)
|
||||
if not self._check_watermark_for_prefill(request, num_new_tokens,
|
||||
blocks, watermark):
|
||||
# Scheduling would exceed watermark, skip.
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
if num_new_tokens > long_prefill_token_threshold \
|
||||
and long_prefill_budget <= 0:
|
||||
skip_cur_request()
|
||||
continue
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens + num_external_computed_tokens,
|
||||
num_new_local_computed_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens,
|
||||
delay_cache_blocks=load_kv_async)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
break
|
||||
|
||||
# KVConnector: update internal state after allocation.
|
||||
# This information is used to determine if a load is
|
||||
# needed for this request.
|
||||
if self.connector is not None:
|
||||
self.connector.update_state_after_alloc(
|
||||
request,
|
||||
new_computed_blocks + new_blocks,
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
|
||||
self.waiting.popleft()
|
||||
if load_kv_async:
|
||||
# If loading async, allocate memory and put request
|
||||
# into the WAITING_FOR_REMOTE_KV state.
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
continue
|
||||
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.SCHEDULED,
|
||||
scheduled_timestamp)
|
||||
self.scheduled_req_ids.add(request.request_id)
|
||||
# Check request status.
|
||||
if request.status == RequestStatus.WAITING:
|
||||
scheduled_new_reqs.append(request)
|
||||
elif request.status == RequestStatus.PREEMPTED:
|
||||
scheduled_resumed_reqs.append(request)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid request status: {request.status}")
|
||||
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
|
||||
req_to_new_blocks[
|
||||
request.request_id] = self.kv_cache_manager.get_blocks(
|
||||
request.request_id)
|
||||
# Update request info.
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
if num_new_tokens > long_prefill_token_threshold:
|
||||
long_prefill_budget -= 1
|
||||
request.status = RequestStatus.RUNNING
|
||||
request.num_computed_tokens = num_computed_tokens
|
||||
# Count the number of prefix cached tokens.
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_budget = new_encoder_budget
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.extendleft(skipped_waiting_requests)
|
||||
|
||||
if self.phase == "decode":
|
||||
while len(
|
||||
self.running
|
||||
) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
|
||||
request = self.finished_prefill_reqs.popleft()
|
||||
self.running.append(request)
|
||||
|
||||
# If no prefill requests are scheduled,
|
||||
# Schedule decode requests next.
|
||||
if len(self.scheduled_req_ids) == 0:
|
||||
req_index = 0
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
if request.request_id in self.scheduled_req_ids:
|
||||
# This request has already been scheduled.
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
num_new_tokens = (request.num_tokens_with_spec -
|
||||
request.num_computed_tokens)
|
||||
assert (request.num_tokens - request.num_computed_tokens) == 1
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
num_new_tokens = min(
|
||||
num_new_tokens,
|
||||
self.max_model_len - request.num_computed_tokens)
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
new_encoder_budget = encoder_budget
|
||||
if request.has_encoder_inputs:
|
||||
(encoder_inputs_to_schedule, num_new_tokens,
|
||||
new_encoder_budget) = self._try_schedule_encoder_inputs(
|
||||
request, request.num_computed_tokens, num_new_tokens,
|
||||
encoder_budget)
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if self.lora_config and request.lora_request and (
|
||||
len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id
|
||||
not in scheduled_loras):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
num_new_tokens = 0
|
||||
|
||||
if num_new_tokens == 0:
|
||||
# The request cannot be scheduled for one of the following
# reasons:
|
||||
# 1. No new tokens to schedule. This may happen when PP>1 and
|
||||
# we have already scheduled all prompt tokens but they are
|
||||
# not finished yet.
|
||||
# 2. Adding the request exceeds the max_loras constraint.
|
||||
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
|
||||
# we do not strictly follow the FCFS scheduling policy and
|
||||
# allow the lower-priority requests to be scheduled.
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
# Preempt the lowest-priority request.
|
||||
preempted_req = self.running.pop()
|
||||
self.kv_cache_manager.free(preempted_req)
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
if self.log_stats:
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED,
|
||||
scheduled_timestamp)
|
||||
self.waiting.appendleft(preempted_req)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt.
|
||||
can_schedule = False
|
||||
break
|
||||
else:
|
||||
# The request can be scheduled.
|
||||
can_schedule = True
|
||||
break
|
||||
if not can_schedule:
|
||||
break
|
||||
assert new_blocks is not None
|
||||
|
||||
# Schedule the request.
|
||||
scheduled_running_reqs.append(request)
|
||||
self.scheduled_req_ids.add(request.request_id)
|
||||
req_to_new_blocks[request.request_id] = new_blocks
|
||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||
token_budget -= num_new_tokens
|
||||
req_index += 1
|
||||
|
||||
# Speculative decode related.
|
||||
if request.spec_token_ids:
|
||||
num_scheduled_spec_tokens = (num_new_tokens +
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids)
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
encoder_inputs_to_schedule)
|
||||
# Allocate the encoder cache.
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_budget = new_encoder_budget
|
||||
|
||||
# Record scheduled LoRA requests.
|
||||
if self.lora_config and request.lora_request:
|
||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||
assert token_budget >= 0
|
||||
assert len(self.running) <= (
self.decode_max_num_running_reqs
if self.phase == "decode" else self.max_num_running_reqs)
|
||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
||||
scheduled_running_reqs) <= len(self.running)
|
||||
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = [0] * len(
|
||||
self.kv_cache_config.kv_cache_groups)
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request, len(self.running)))
|
||||
|
||||
# Construct the scheduler output.
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(
|
||||
req, req_to_new_blocks[req.request_id].get_block_ids())
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs, scheduled_resumed_reqs,
|
||||
num_scheduled_tokens, scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks)
|
||||
scheduled_cached_reqs = cached_reqs_data
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids, # type: ignore
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.
|
||||
get_freed_mm_hashes(),
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
# 1. Plan the KV cache store
|
||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||
# 3. Clear the internal states of the connector
|
||||
if self.connector is not None:
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
events = self.kv_cache_manager.take_events()
|
||||
if events:
|
||||
batch = KVEventBatch(ts=time.time(), events=events)
|
||||
self.kv_event_publisher.publish(batch)
|
||||
|
||||
# Advance the number of computed tokens for the request AFTER
|
||||
# the request is scheduled.
|
||||
# 1. The scheduler_output of the current step has to include the
|
||||
# original number of scheduled tokens to determine input IDs.
|
||||
# 2. Advance the number of computed tokens here allowing us to
|
||||
# schedule the prefill request again immediately in the next
|
||||
# scheduling step.
|
||||
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
|
||||
# computed tokens will be adjusted in update_from_output.
|
||||
for req_id, num_scheduled_token in num_scheduled_tokens.items():
|
||||
self.requests[req_id].num_computed_tokens += num_scheduled_token
|
||||
|
||||
self.finished_req_ids = set() # type: ignore
|
||||
return scheduler_output
|
||||
|
||||
def _check_watermark_for_prefill(self,
|
||||
request,
|
||||
num_new_tokens,
|
||||
computed_blocks,
|
||||
watermark=0.01):
|
||||
computed_blocks = computed_blocks or []
|
||||
watermark_blocks = self.kv_cache_config.num_blocks * watermark
|
||||
num_computed_tokens = (request.num_computed_tokens +
|
||||
len(computed_blocks) * self.block_size)
|
||||
num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
|
||||
self.block_size)
|
||||
req_blocks = self.kv_cache_manager.coordinator.get_blocks(
|
||||
request.request_id)
|
||||
num_new_blocks = (num_required_blocks - len(req_blocks[0]) -
|
||||
len(computed_blocks))
|
||||
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
|
||||
if blk.ref_cnt == 0)
|
||||
# If the number of free blocks would drop below the watermark after allocating, don't allocate.
|
||||
if (self.kv_cache_manager.block_pool.get_num_free_blocks() -
|
||||
num_evictable_computed_blocks -
|
||||
num_new_blocks) < watermark_blocks:
|
||||
return False
|
||||
return True
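# A minimal sketch (assumption: not part of this patch) that mirrors the
# watermark arithmetic above, with a small worked example in the comments.
def _would_allocate_sketch(num_free_blocks, num_evictable_blocks,
                           num_new_blocks, num_total_blocks, watermark=0.01):
    # Allocation is allowed only if, after subtracting evictable prefix blocks
    # and the newly required blocks, the free pool stays above the watermark.
    watermark_blocks = num_total_blocks * watermark
    return (num_free_blocks - num_evictable_blocks -
            num_new_blocks) >= watermark_blocks
# Example: 1000 total blocks -> 10 must stay free. With 12 free blocks, 1
# evictable block and 4 new blocks needed, 12 - 1 - 4 = 7 < 10, so the
# prefill is deferred.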
|
||||
|
||||
def _get_prompt_limit(self, request: Request) -> int:
|
||||
if (self.scheduler_config.chunked_prefill_enabled
|
||||
and not self.scheduler_config.is_multi_step):
|
||||
prompt_limit = self.scheduler_config.max_model_len
|
||||
else:
|
||||
prompt_limit = min(
|
||||
self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
|
||||
# The model is fine-tuned for long context. Return the fine-tuned max_len.
|
||||
if request.lora_request and request.lora_request.long_lora_max_len:
|
||||
assert prompt_limit <= request.lora_request.long_lora_max_len
|
||||
return request.lora_request.long_lora_max_len
|
||||
else:
|
||||
return prompt_limit
|
||||
|
||||
def finish_requests(
|
||||
self,
|
||||
request_ids: Union[str, Iterable[str]],
|
||||
finished_status: RequestStatus,
|
||||
) -> None:
|
||||
"""Handles the finish signal from outside the scheduler.
|
||||
|
||||
For example, the API server can abort a request when the client
|
||||
disconnects.
|
||||
"""
|
||||
for req_id in request_ids:
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
# Invalid request ID.
|
||||
continue
|
||||
if request.status == RequestStatus.RUNNING:
|
||||
self.scheduled_req_ids.discard(request.request_id)
|
||||
super().finish_requests(request_ids, finished_status)
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> EngineCoreOutputs:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
|
||||
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
||||
# loop can be a performance bottleneck. We should do our best to avoid
|
||||
# expensive operations inside the loop.
|
||||
for request in self.running:
|
||||
req_id = request.request_id
|
||||
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
||||
if num_tokens_scheduled == 0:
|
||||
# The request was not scheduled in this step.
|
||||
continue
|
||||
if req_id in self.scheduled_req_ids:
|
||||
self.scheduled_req_ids.remove(req_id)
|
||||
|
||||
return super().update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
330
vllm_npu/cpu_binding.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import torch_npu
|
||||
from vllm.logger import logger
|
||||
|
||||
ASCEND_RT_VISIBLE_DEVICES = os.getenv("ASCEND_RT_VISIBLE_DEVICES")
|
||||
CPU_BINDING_NUM = os.getenv("CPU_BINDING_NUM")
|
||||
|
||||
|
||||
def execute_command(cmd_list):
|
||||
with subprocess.Popen(cmd_list,
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE) as p:
|
||||
out, err = p.communicate(timeout=1000)
|
||||
res = out.decode()
|
||||
return res
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceInfo:
|
||||
"""
|
||||
Parse a single line of device information into structured data.
|
||||
"""
|
||||
_info_line: str = ""
|
||||
npu_id: int = 0
|
||||
chip_id: int = 0
|
||||
chip_logic_id: Union[int, str] = 0
|
||||
chip_name: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
npu_id_str, chip_id_str, chip_logic_id_str, self.chip_name = self._info_line.strip(
|
||||
).split(None, 3)
|
||||
self.npu_id = int(npu_id_str)
|
||||
self.chip_id = int(chip_id_str)
|
||||
if chip_logic_id_str.isnumeric():
|
||||
self.chip_logic_id = int(chip_logic_id_str)
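# Illustrative parse of one `npu-smi info -m` row (sample values are assumed,
# real rows depend on the host):
#     DeviceInfo("0       0       0       Ascend 910B")
# yields npu_id=0, chip_id=0, chip_logic_id=0, chip_name="Ascend 910B"; rows
# whose chip_logic_id column is not numeric keep the default chip_logic_id.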
|
||||
|
||||
|
||||
class NpuHbmInfo:
|
||||
visible_npu_ids: Optional[List[int]] = None
|
||||
hbm_capacity: Optional[int] = None
|
||||
hbm_usage: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def set_visible_devices(cls, world_size):
|
||||
"""
|
||||
Determine which NPUs are visible to the current process and cache their
|
||||
logical NPU IDs in `cls.visible_npu_ids`.
|
||||
"""
|
||||
if cls.visible_npu_ids:
|
||||
return
|
||||
if ASCEND_RT_VISIBLE_DEVICES is None:
|
||||
devices = sorted(list(_get_device_map_info().keys()))
|
||||
else:
|
||||
devices_str = ASCEND_RT_VISIBLE_DEVICES
|
||||
devices = [int(x) for x in devices_str.split(",")]
|
||||
device_map_info = _get_device_map_info()
|
||||
npu_ids = []
|
||||
for device in devices:
|
||||
device_info = device_map_info.get(device)
|
||||
if device_info is None:
|
||||
raise RuntimeError(
|
||||
f"Device {device} not found in device_map_info")
|
||||
npu_ids.append(device_info.npu_id)
|
||||
cls.visible_npu_ids = npu_ids
|
||||
|
||||
@classmethod
|
||||
def get_hbm_capacity(cls, rank, world_size, need_nz):
|
||||
"""
|
||||
Query and cache the HBM (or DDR) capacity in **bytes** for the NPU assigned
|
||||
to the current process.
|
||||
"""
|
||||
soc_version = torch_npu._C._npu_get_soc_version()
|
||||
if cls.hbm_capacity:
|
||||
return cls.hbm_capacity
|
||||
if not cls.visible_npu_ids:
|
||||
cls.set_visible_devices(world_size)
|
||||
assert cls.visible_npu_ids is not None
|
||||
npu_id = cls.visible_npu_ids[rank]
|
||||
memory_info = execute_command(
|
||||
["npu-smi", "info", "-i", f"{npu_id}", "-t",
|
||||
"memory"]).split("\n")[1:]
|
||||
if soc_version == 240:
|
||||
hbm_capacity_key = 'Capacity(MB)'
|
||||
elif not need_nz:
|
||||
hbm_capacity_key = 'HBM Capacity(MB)'
|
||||
else:
|
||||
hbm_capacity_key = 'DDR Capacity(MB)'
|
||||
for line in memory_info:
|
||||
try:
|
||||
key, value = line.strip().split(':', 2)
|
||||
if key.strip() == hbm_capacity_key:
|
||||
cls.hbm_capacity = int(value.strip()) * 1024 * 1024
|
||||
return cls.hbm_capacity
|
||||
except ValueError:
|
||||
pass
|
||||
raise ValueError('no valid HBM capacity found')
|
||||
|
||||
@classmethod
|
||||
def get_hbm_usage(cls, rank, world_size, need_nz):
|
||||
"""
|
||||
Return the current HBM or DDR usage
|
||||
ratio (0-1) for the NPU assigned to the given rank.
|
||||
"""
|
||||
if cls.hbm_usage:
|
||||
return cls.hbm_usage
|
||||
if not cls.visible_npu_ids:
|
||||
cls.set_visible_devices(world_size)
|
||||
assert cls.visible_npu_ids is not None
|
||||
npu_id = cls.visible_npu_ids[rank]
|
||||
usage_info = execute_command(
|
||||
["npu-smi", "info", "-i", f"{npu_id}", "-t",
|
||||
"usages"]).split("\n")[1:]
|
||||
soc_version = torch_npu._C._npu_get_soc_version()
|
||||
if soc_version == 240:
|
||||
hbm_capacity_key = 'Memory Usage Rate(%)'
|
||||
elif not need_nz:
|
||||
hbm_capacity_key = 'HBM Usage Rate(%)'
|
||||
else:
|
||||
hbm_capacity_key = 'DDR Usage Rate(%)'
|
||||
for line in usage_info:
|
||||
try:
|
||||
key, value = line.strip().split(':', 2)
|
||||
if key.strip() == hbm_capacity_key:
|
||||
hbm_usage = (float(value.strip()) + 1) / 100
|
||||
return hbm_usage
|
||||
except ValueError:
|
||||
pass
|
||||
raise ValueError('no valid HBM usage found')
|
||||
|
||||
|
||||
def _get_device_map_info() -> Dict[int, DeviceInfo]:
|
||||
"""
|
||||
Build and return a mapping from logical chip ID (int) to its DeviceInfo object.
|
||||
"""
|
||||
device_map_info = {}
|
||||
device_map = execute_command(["npu-smi", "info",
|
||||
"-m"]).strip().split("\n")[1:]
|
||||
for line in device_map:
|
||||
device_info = DeviceInfo(line.strip())
|
||||
if isinstance(device_info.chip_logic_id, int):
|
||||
device_map_info[device_info.chip_logic_id] = device_info
|
||||
return device_map_info
|
||||
|
||||
|
||||
def _get_pcie_info(devices: List[int], keyword="PCIeBusInfo"):
|
||||
"""
|
||||
Query each NPU in the given device list and return a mapping
|
||||
from logical device ID to its PCIe bus address.
|
||||
"""
|
||||
device_map_info = _get_device_map_info()
|
||||
device_pcie_tbl = {}
|
||||
for device in devices:
|
||||
device_info = device_map_info.get(device)
|
||||
if not device_info:
|
||||
raise RuntimeError(
|
||||
"Can not get device info, you can use BIND_CPU=0 to skip.")
|
||||
pcie_info = execute_command([
|
||||
"npu-smi", "info", "-t", "board", "-i", f"{device_info.npu_id}",
|
||||
"-c", f"{device_info.chip_id}"
|
||||
]).strip().split("\n")
|
||||
for _ in pcie_info:
|
||||
line = ''.join(_.split())
|
||||
if line.startswith(keyword):
|
||||
device_pcie_tbl[device] = line[len(keyword) + 1:]
|
||||
break
|
||||
|
||||
return device_pcie_tbl
|
||||
|
||||
|
||||
def _get_numa_info(pcie_tbl, keyword="NUMAnode"):
|
||||
"""
|
||||
Build two mappings: device → NUMA node, and NUMA node → [devices].
|
||||
"""
|
||||
device_numa_tbl: Dict[int, int] = {} # device id -> numa id
|
||||
numa_devices_tbl: Dict[int, List[int]] = {} # numa id -> device ids
|
||||
|
||||
for device, pcie_no in pcie_tbl.items():
|
||||
numa_info = execute_command(["lspci", "-s", f"{pcie_no}",
|
||||
"-vvv"]).split("\n")
|
||||
for _ in numa_info:
|
||||
line = ''.join(_.split())
|
||||
if line.startswith(keyword):
|
||||
numa_id = int(line[len(keyword) + 1:])
|
||||
device_numa_tbl[device] = numa_id
|
||||
|
||||
devices = numa_devices_tbl.get(numa_id, None)
|
||||
if devices is None:
|
||||
numa_devices_tbl[numa_id] = list()
|
||||
|
||||
numa_devices_tbl[numa_id].append(device)
|
||||
break
|
||||
|
||||
return device_numa_tbl, numa_devices_tbl
|
||||
|
||||
|
||||
def _get_numa_info_v2(
|
||||
devices: List[int],
|
||||
keyword="NUMAnode(s)") -> Tuple[Dict[int, int], Dict[int, List[int]]]:
|
||||
"""
|
||||
Evenly distribute the given device list across all NUMA nodes and return
|
||||
both device-to-numa and numa-to-devices mappings.
|
||||
"""
|
||||
numa_nodes = 1
|
||||
numa_info = execute_command(["lscpu"]).split("\n")
|
||||
for _ in numa_info:
|
||||
line = ''.join(_.split())
|
||||
if keyword not in line:
|
||||
continue
|
||||
numa_nodes = int(line[-1])
|
||||
break
|
||||
|
||||
device_per_numa, tail_device = divmod(len(devices), numa_nodes)
|
||||
device_count_per_numa_list = [
|
||||
device_per_numa + (i < tail_device) for i in range(numa_nodes)
|
||||
]
|
||||
|
||||
ends = list(accumulate(device_count_per_numa_list))
|
||||
starts = [0] + ends[:-1]
|
||||
|
||||
numa_devices_tbl = {
|
||||
ind: devices[start:end]
|
||||
for ind, (start, end) in enumerate(zip(starts, ends))
|
||||
}
|
||||
|
||||
device_numa_tbl = {
|
||||
device: numa
|
||||
for numa, _devices in numa_devices_tbl.items()
|
||||
for device in _devices
|
||||
}
|
||||
|
||||
return device_numa_tbl, numa_devices_tbl
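# Worked example of the even split above (hypothetical: devices=[0..5] spread
# over 4 NUMA nodes reported by lscpu):
#     divmod(6, 4) == (1, 2)  ->  device_count_per_numa_list == [2, 2, 1, 1]
#     numa_devices_tbl == {0: [0, 1], 1: [2, 3], 2: [4], 3: [5]}
#     device_numa_tbl  == {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 3}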
|
||||
|
||||
|
||||
def _get_cpu_info(numa_ids, keyword1="NUMAnode", keyword2="CPU(s)"):
|
||||
"""
|
||||
Parse lscpu output to build a dict that maps each NUMA
|
||||
node ID to the list of CPU core IDs belonging to it.
|
||||
"""
|
||||
cpu_idx_tbl = dict()
|
||||
numa_keywords = [keyword1 + str(idx) + keyword2 for idx in numa_ids]
|
||||
cpu_info = execute_command(["lscpu"]).split("\n")
|
||||
for _ in cpu_info:
|
||||
line = ''.join(_.split())
|
||||
if any(line.startswith(word) for word in numa_keywords):
|
||||
split_info = line.split(":")
|
||||
cpu_id_ranges = split_info[-1].split(",")
|
||||
|
||||
ranges = list()
|
||||
for range_str in cpu_id_ranges:
|
||||
endpoints = range_str.split("-")
|
||||
if len(endpoints) != 2:
|
||||
raise Exception(
|
||||
"lscpu command output error, please check !")
|
||||
|
||||
ranges += [
|
||||
cid for cid in range(int(endpoints[0]),
|
||||
int(endpoints[1]) + 1)
|
||||
]
|
||||
|
||||
numa_id = int(split_info[0].replace(keyword1,
|
||||
'').replace(keyword2, ''))
|
||||
cpu_idx_tbl[numa_id] = ranges
|
||||
return cpu_idx_tbl
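# Illustrative lscpu row this parser consumes (values are assumed):
#     "NUMA node0 CPU(s):   0-23,96-119"
# which, after whitespace is stripped, yields
#     cpu_idx_tbl[0] == list(range(0, 24)) + list(range(96, 120))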
|
||||
|
||||
|
||||
def bind_cpus(rank_id, ratio=0.5):
|
||||
# get all visible devices
|
||||
visible_devices = ASCEND_RT_VISIBLE_DEVICES
|
||||
|
||||
if visible_devices is None:
|
||||
devices = sorted(list(_get_device_map_info().keys()))
|
||||
else:
|
||||
devices = [int(x) for x in visible_devices.split(",")]
|
||||
|
||||
# Query the NUMA affinity of each NPU via its PCIe address; if this fails,
|
||||
# fall back to evenly distributing the devices across NUMA nodes.
|
||||
device_pcie_tbl = _get_pcie_info(devices)
|
||||
device_numa_tbl, numa_devices_tbl = _get_numa_info(device_pcie_tbl)
|
||||
if not device_numa_tbl or not numa_devices_tbl:
|
||||
device_numa_tbl, numa_devices_tbl = _get_numa_info_v2(devices)
|
||||
|
||||
# Obtain the complete list of CPU cores for each NUMA node.
|
||||
cpu_idx_tbl = _get_cpu_info(list(numa_devices_tbl.keys()))
|
||||
|
||||
cur_device = devices[rank_id]
|
||||
numa_id = device_numa_tbl.get(cur_device)
|
||||
|
||||
# Within the NUMA node, evenly partition the CPU cores
|
||||
# among all NPUs (or use the amount specified by CPU_BINDING_NUM)
|
||||
shard_devices = numa_devices_tbl.get(numa_id)
|
||||
shard_devices.sort()
|
||||
|
||||
all_cpus = cpu_idx_tbl.get(numa_id)
|
||||
logger.info(
|
||||
f"rank_id: {rank_id}, device_id: {cur_device}, "
|
||||
f"numa_id: {numa_id}, shard_devices: {shard_devices}, cpus: {all_cpus}"
|
||||
)
|
||||
|
||||
cpu_nums = len(all_cpus)
|
||||
if CPU_BINDING_NUM is None:
|
||||
cpu_num_per_device = int(cpu_nums * ratio // len(shard_devices))
|
||||
else:
|
||||
cpu_num_per_device = int(CPU_BINDING_NUM)
|
||||
if len(shard_devices) * cpu_num_per_device > cpu_nums:
|
||||
raise RuntimeError(
|
||||
f"Cpu num in numa {numa_id} to assign {cpu_num_per_device} for every device is not enough, "
|
||||
f"please decrease the value of CPU_BINDING_NUM!")
|
||||
if cpu_num_per_device < 0:
|
||||
raise ValueError("CPU_BINDING_NUM should not be less than 0.")
|
||||
|
||||
idx = shard_devices.index(cur_device)
|
||||
binding_cpus = [
|
||||
all_cpus[_] for _ in range(idx * cpu_num_per_device, (idx + 1) *
|
||||
cpu_num_per_device)
|
||||
]
|
||||
|
||||
# cpu bind
|
||||
p = psutil.Process()
|
||||
p.cpu_affinity(binding_cpus)
|
||||
new_affinity = p.cpu_affinity()
|
||||
logger.info(
|
||||
f"process {p.pid}, new_affinity is {new_affinity}, cpu count {cpu_num_per_device}"
|
||||
)
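# Hypothetical call-site sketch (maybe_bind_cpus and local_rank are assumed,
# not part of this patch); BIND_CPU is the env var referenced in the error
# message inside _get_pcie_info above.
def maybe_bind_cpus(local_rank: int) -> None:
    if os.getenv("BIND_CPU", "1") != "1":
        return
    try:
        # Binding is best-effort: never fail worker start-up because of it.
        bind_cpus(local_rank)
    except Exception as e:
        logger.warning(f"CPU binding skipped: {e}")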
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
CUDA-to-NPU Compatibility Layer.
|
||||
|
||||
Monkey-patches ``torch.cuda`` APIs so that code written for CUDA
|
||||
(e.g. ``torch.cuda.Stream()``, ``torch.cuda.Event()``) transparently
|
||||
delegates to the corresponding ``torch.npu`` equivalents. This
|
||||
allows vLLM's ``GPUModelRunner`` to run on Ascend NPU without
|
||||
source modifications.
|
||||
"""
|
||||
|
||||
import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _patch_cuda_to_npu() -> None:
|
||||
"""Apply monkey-patches: redirect torch.cuda → torch.npu."""
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stream / Event
|
||||
# ------------------------------------------------------------------
|
||||
torch.cuda.Stream = torch.npu.Stream # type: ignore[attr-defined]
|
||||
torch.cuda.Event = torch.npu.Event # type: ignore[attr-defined]
|
||||
torch.cuda.current_stream = torch.npu.current_stream # type: ignore
|
||||
torch.cuda.default_stream = torch.npu.default_stream # type: ignore
|
||||
|
||||
# torch.cuda.stream() context manager
|
||||
torch.cuda.stream = torch.npu.stream # type: ignore[attr-defined]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Device management
|
||||
# ------------------------------------------------------------------
|
||||
torch.cuda.set_device = torch.npu.set_device # type: ignore
|
||||
torch.cuda.synchronize = torch.npu.synchronize # type: ignore
|
||||
torch.cuda.device_count = torch.npu.device_count # type: ignore
|
||||
torch.cuda.current_device = torch.npu.current_device # type: ignore
|
||||
torch.cuda.is_available = lambda: True # type: ignore
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Memory management
|
||||
# ------------------------------------------------------------------
|
||||
torch.cuda.empty_cache = torch.npu.empty_cache # type: ignore
|
||||
torch.cuda.mem_get_info = torch.npu.mem_get_info # type: ignore
|
||||
torch.cuda.memory_allocated = torch.npu.memory_allocated # type: ignore
|
||||
torch.cuda.max_memory_allocated = torch.npu.max_memory_allocated # type: ignore
|
||||
torch.cuda.memory_reserved = torch.npu.memory_reserved # type: ignore
|
||||
torch.cuda.max_memory_reserved = torch.npu.max_memory_reserved # type: ignore
|
||||
torch.cuda.reset_peak_memory_stats = torch.npu.reset_peak_memory_stats # type: ignore
|
||||
torch.cuda.memory_stats = torch.npu.memory_stats # type: ignore
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Device properties
|
||||
# ------------------------------------------------------------------
|
||||
_real_npu_props = torch.npu.get_device_properties
|
||||
|
||||
def _get_device_properties(device=None):
|
||||
"""Return NPU device properties with CUDA-compatible attributes."""
|
||||
props = _real_npu_props(device)
|
||||
# GPUModelRunner accesses .multi_processor_count which may not
|
||||
# exist on NPU. Provide a sensible fallback.
|
||||
if not hasattr(props, "multi_processor_count"):
|
||||
props.multi_processor_count = 1 # type: ignore[attr-defined]
|
||||
if not hasattr(props, "major"):
|
||||
props.major = 9 # type: ignore[attr-defined]
|
||||
props.minor = 0 # type: ignore[attr-defined]
|
||||
return props
|
||||
|
||||
torch.cuda.get_device_properties = _get_device_properties # type: ignore
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Misc
|
||||
# ------------------------------------------------------------------
|
||||
if not hasattr(torch.cuda, "_get_device_index"):
|
||||
torch.cuda._get_device_index = torch.npu._get_device_index # type: ignore
|
||||
|
||||
# graph / CUDAGraph stubs (NPU does not support CUDA graphs)
|
||||
if not hasattr(torch.cuda, "CUDAGraph") or True:
|
||||
torch.cuda.CUDAGraph = MagicMock # type: ignore[attr-defined]
|
||||
|
||||
if not hasattr(torch.cuda, "graph"):
|
||||
|
||||
def _noop_graph(*args, **kwargs):
|
||||
"""No-op context manager for CUDA graphs on NPU."""
|
||||
import contextlib
|
||||
return contextlib.nullcontext()
|
||||
|
||||
torch.cuda.graph = _noop_graph # type: ignore[attr-defined]
|
||||
0
vllm_npu/device_allocator/__init__.py
Normal file
278
vllm_npu/device_allocator/camem.py
Normal file
@@ -0,0 +1,278 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# CANN-mem-based pytorch pluggable allocator to implement sleep mode.
|
||||
#
|
||||
import dataclasses
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from acl.rt import memcpy # type: ignore # noqa: F401
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_npu.platform import NPUPlatform
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
|
||||
""" # noqa
|
||||
found_line = None
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found_line = line
|
||||
break
|
||||
if found_line is None:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = found_line.index("/")
|
||||
path = found_line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
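# Illustrative /proc/self/maps line this helper matches (path is hypothetical):
#     7f2c3a000000-7f2c3a200000 r-xp 00000000 08:02 123456  /usr/lib/vllm_npu_C.cpython-310-x86_64-linux-gnu.so
# find_loaded_library("vllm_npu_C") returns everything from the first "/" on,
# i.e. the absolute path of the shared object.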
|
||||
|
||||
|
||||
camem_available = False
|
||||
try:
|
||||
from vllm_npu.vllm_npu_C import ( # type: ignore # noqa: F401
|
||||
init_module, python_create_and_map, python_unmap_and_release)
|
||||
lib_name = find_loaded_library("vllm_npu_C")
|
||||
camem_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
"Failed to import vllm_npu_C:%s. Sleep mode will be disabled. ", e)
|
||||
init_module = None
|
||||
python_create_and_map = None
|
||||
python_unmap_and_release = None
|
||||
lib_name = None
|
||||
libcudart = None
|
||||
|
||||
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
|
||||
HandleType = Tuple[int, int, int, int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AllocationData:
|
||||
handle: HandleType
|
||||
tag: str
|
||||
cpu_backup_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def create_and_map(allocation_handle: HandleType) -> None:
|
||||
python_create_and_map(*allocation_handle)
|
||||
|
||||
|
||||
def unmap_and_release(allocation_handle: HandleType) -> None:
|
||||
python_unmap_and_release(*allocation_handle)
|
||||
|
||||
|
||||
def get_pluggable_allocator(
|
||||
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
|
||||
python_free_func: Callable[[int], tuple[int, int, int, int]]
|
||||
) -> torch.npu.memory.NPUPluggableAllocator:
|
||||
init_module(python_malloc_fn, python_free_func)
|
||||
new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, 'my_malloc',
|
||||
'my_free')
|
||||
return new_alloc
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool_with_allocator(
|
||||
python_malloc_fn: Callable[[tuple[int, int, int, int]], None],
|
||||
python_free_func: Callable[[int], tuple[int, int, int, int]]):
|
||||
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
||||
mem_pool = torch.npu.memory.MemPool(new_alloc._allocator)
|
||||
with torch.npu.memory.use_mem_pool(mem_pool):
|
||||
yield mem_pool, new_alloc
|
||||
|
||||
|
||||
class CaMemAllocator:
|
||||
"""
|
||||
A singleton class that manages a memory pool for CANN tensors.
|
||||
The memory in this pool can be offloaded or discarded when the
|
||||
allocator sleeps.
|
||||
Inside the `use_memory_pool(tag)` context, all tensors created will
|
||||
be allocated in the memory pool and will have the same tag as the
one passed to the context.
|
||||
When we call `sleep`, all tensors with the specified tag will be
|
||||
offloaded to CPU memory, and the rest of the tensors will be discarded.
|
||||
When we call `wake_up`, all tensors that were previously offloaded
|
||||
will be loaded back to GPU memory, and the rest of the tensors will
|
||||
have empty memory.
|
||||
Why does it need to be a singleton?
|
||||
When allocated tensors are garbage collected, PyTorch will call
|
||||
the free callback, which will call the `python_free_callback` method.
|
||||
The C-extension uses a global variable to store the function of an
|
||||
instance of this class. If we create multiple instances of this class,
|
||||
the global variable will be overwritten and the free callback will
|
||||
not work as expected.
|
||||
"""
|
||||
instance = None
|
||||
default_tag: str = "default"
|
||||
|
||||
@staticmethod
|
||||
def get_instance() -> "CaMemAllocator":
|
||||
"""
|
||||
CaMemAllocator is a singleton class.
|
||||
We cannot call the constructor directly.
|
||||
Call this method to get the instance.
|
||||
"""
|
||||
if CaMemAllocator.instance is None:
|
||||
CaMemAllocator.instance = CaMemAllocator()
|
||||
return CaMemAllocator.instance
|
||||
|
||||
def __init__(self):
|
||||
conf = os.environ.get("PYTORCH_NPU_ALLOC_CONF", "")
|
||||
assert "expandable_segments:True" not in conf, \
|
||||
("Expandable segments are not compatible with memory pool. "
|
||||
"Please track https://github.com/pytorch/pytorch/issues/147851 "
|
||||
"for the latest updates.")
|
||||
|
||||
self.pointer_to_data: Dict[int, AllocationData] = {}
|
||||
self.current_tag: str = CaMemAllocator.default_tag
|
||||
self.allocator_and_pools: Dict[str, Any] = {}
|
||||
|
||||
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
|
||||
"""
|
||||
Internal method to store the allocation data
|
||||
when memory is allocated in the memory pool."""
|
||||
py_d_mem = allocation_handle[2]
|
||||
self.pointer_to_data[py_d_mem] = AllocationData(
|
||||
allocation_handle, self.current_tag)
|
||||
return
|
||||
|
||||
def python_free_callback(self, ptr: int) -> HandleType:
|
||||
"""
|
||||
Internal method to look up the allocation data
|
||||
when memory is freed in the memory pool."""
|
||||
data = self.pointer_to_data.pop(ptr)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
data.cpu_backup_tensor = None
|
||||
return data.handle
|
||||
|
||||
def sleep(
|
||||
self,
|
||||
offload_tags: Optional[Union[Tuple[str, ...],
|
||||
str]] = None) -> None:
|
||||
"""
|
||||
Put the allocator in sleep mode.
|
||||
All data in the memory allocation with the specified tag will be
|
||||
offloaded to CPU memory, and others will be discarded.
|
||||
:param offload_tags: The tags of the memory allocation that will be
|
||||
offloaded. The rest of the memory allocation will be discarded.
|
||||
"""
|
||||
if offload_tags is None:
|
||||
# by default, allocated tensors are offloaded
|
||||
# when the allocator sleeps
|
||||
offload_tags = (CaMemAllocator.default_tag, )
|
||||
elif isinstance(offload_tags, str):
|
||||
offload_tags = (offload_tags, )
|
||||
|
||||
assert isinstance(offload_tags, tuple)
|
||||
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
if data.tag in offload_tags:
|
||||
size_in_bytes = handle[1]
|
||||
cpu_backup_tensor = torch.empty(
|
||||
size_in_bytes,
|
||||
dtype=torch.uint8,
|
||||
device='cpu',
|
||||
pin_memory=NPUPlatform.is_pin_memory_available())
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
ACL_MEMCPY_DEVICE_TO_HOST = 2
|
||||
dest_max = cpu_ptr + size_in_bytes * 2
|
||||
memcpy(cpu_ptr, dest_max, ptr, size_in_bytes,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST)
|
||||
data.cpu_backup_tensor = cpu_backup_tensor
|
||||
unmap_and_release(handle)
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Wake up the allocator from sleep mode.
|
||||
All data that is previously offloaded will be loaded back to GPU
|
||||
memory, and the rest of the data will have empty memory."""
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
if tags is None or data.tag in tags:
|
||||
handle = data.handle
|
||||
create_and_map(handle)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
cpu_backup_tensor = data.cpu_backup_tensor
|
||||
if cpu_backup_tensor is not None:
|
||||
size_in_bytes = cpu_backup_tensor.numel(
|
||||
) * cpu_backup_tensor.element_size()
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
ACL_MEMCPY_HOST_TO_DEVICE = 1
|
||||
dest_max = ptr + size_in_bytes * 2
|
||||
memcpy(ptr, dest_max, cpu_ptr, size_in_bytes,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE)
|
||||
data.cpu_backup_tensor = None
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool(self, tag: Optional[str] = None):
|
||||
"""
|
||||
A context manager to use the memory pool.
|
||||
All memory allocations created inside the context will be made
in the memory pool and will carry the specified tag.
|
||||
:param tag: The tag of the memory allocation. If None, the default tag
|
||||
will be used.
|
||||
"""
|
||||
if tag is None:
|
||||
tag = CaMemAllocator.default_tag
|
||||
|
||||
assert isinstance(tag, str)
|
||||
|
||||
old_tag = self.current_tag
|
||||
self.current_tag = tag
|
||||
with use_memory_pool_with_allocator(self.python_malloc_callback,
|
||||
self.python_free_callback) as data:
|
||||
# We start hitting another PyTorch bug in PyTorch 2.6,
|
||||
# possibly because of gc-related issue w.r.t. the allocator and
|
||||
# the memory pool.
|
||||
# to avoid the issue, we keep a reference of the data.
|
||||
# see https://github.com/pytorch/pytorch/issues/146431 .
|
||||
self.allocator_and_pools[tag] = data
|
||||
yield
|
||||
# Due to a PyTorch bug, calling torch.cuda.empty_cache() will error
|
||||
# when using pluggable allocator, see
|
||||
# https://github.com/pytorch/pytorch/issues/145168 .
|
||||
# if we have some memory allocated and then freed,
|
||||
# the memory will not be released.
|
||||
# right now it is fine, because we only use this allocator
|
||||
# during weight loading and kv cache creation, where we only
|
||||
# allocate memory.
|
||||
# TODO: we need to find a way to release the memory,
|
||||
# i.e. calling torch.cuda.empty_cache()
|
||||
self.current_tag = old_tag
|
||||
|
||||
def get_current_usage(self) -> int:
|
||||
"""
|
||||
Get the total number of bytes allocated in the memory pool.
|
||||
"""
|
||||
sum_bytes: int = 0
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
sum_bytes += handle[1]
|
||||
return sum_bytes
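# Minimal usage sketch of the sleep/wake flow described in the class docstring
# (tag name and tensor size are illustrative, not taken from this patch):
#     allocator = CaMemAllocator.get_instance()
#     with allocator.use_memory_pool(tag="weights"):
#         w = torch.empty(1 << 20, dtype=torch.uint8, device="npu")
#     allocator.sleep(offload_tags="weights")   # offload "weights" to CPU, release NPU memory
#     allocator.wake_up(tags=["weights"])       # re-map NPU memory and copy the data back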
|
||||
@@ -1 +1,40 @@
|
||||
"""Ascend NPU distributed communication (HCCL)."""
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import \
|
||||
KVConnectorFactory
|
||||
|
||||
|
||||
def register_connector():
|
||||
KVConnectorFactory.register_connector(
|
||||
"LLMDataDistCMgrConnector",
|
||||
"vllm_npu.distributed.llmdatadist_c_mgr_connector",
|
||||
"LLMDataDistCMgrConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnectorV1", "vllm_npu.distributed.mooncake_connector",
|
||||
"MooncakeConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnectorStoreV1",
|
||||
"vllm_npu.distributed.mooncake.mooncake_store_connector_v1",
|
||||
"MooncakeConnectorV1")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeLayerwiseConnector",
|
||||
"vllm_npu.distributed.mooncake_layerwise_connector",
|
||||
"MooncakeLayerwiseConnector")
|
||||
|
||||
@@ -1,42 +1,46 @@
|
||||
"""
|
||||
NPUCommunicator — HCCL-based device communicator for Ascend NPU.
|
||||
|
||||
Extends ``DeviceCommunicatorBase`` with NPU-specific collective
|
||||
operations using the HCCL backend.
|
||||
"""
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase,
|
||||
)
|
||||
from vllm.distributed.device_communicators.base_device_communicator import \
|
||||
DeviceCommunicatorBase
|
||||
|
||||
|
||||
class NPUCommunicator(DeviceCommunicatorBase):
|
||||
"""Device communicator for Ascend NPU using HCCL."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: dist.ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[dist.ProcessGroup] = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
def __init__(self,
|
||||
cpu_group: dist.ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[dist.ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
import torch_npu # noqa: F401
|
||||
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
|
||||
# init device according to rank
|
||||
self.device = torch.npu.current_device()
|
||||
|
||||
def all_to_all(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
scatter_dim: int = 0,
|
||||
gather_dim: int = -1,
|
||||
scatter_sizes: Optional[List[int]] = None,
|
||||
gather_sizes: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""All-to-all communication for NPU tensors."""
|
||||
def all_to_all(self,
|
||||
input_: torch.Tensor,
|
||||
scatter_dim: int = 0,
|
||||
gather_dim: int = -1,
|
||||
scatter_sizes: Optional[List[int]] = None,
|
||||
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
|
||||
|
||||
if scatter_dim < 0:
|
||||
scatter_dim += input_.dim()
|
||||
if gather_dim < 0:
|
||||
@@ -53,22 +57,17 @@ class NPUCommunicator(DeviceCommunicatorBase):
|
||||
tensor_shape = list(tensor_shape_base)
|
||||
tensor_shape[gather_dim] = gather_sizes[i]
|
||||
output_list.append(
|
||||
torch.empty(
|
||||
tensor_shape,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device,
|
||||
)
|
||||
)
|
||||
torch.empty(tensor_shape,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device))
|
||||
|
||||
else:
|
||||
input_list = [
|
||||
t.contiguous()
|
||||
for t in torch.tensor_split(
|
||||
input_, self.world_size, scatter_dim
|
||||
)
|
||||
t.contiguous() for t in torch.tensor_split(
|
||||
input_, self.world_size, scatter_dim)
|
||||
]
|
||||
output_list = [
|
||||
torch.empty_like(input_list[i])
|
||||
for i in range(self.world_size)
|
||||
torch.empty_like(input_list[i]) for i in range(self.world_size)
|
||||
]
|
||||
|
||||
dist.all_to_all(output_list, input_list, group=self.device_group)
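# Shape sketch for the branch above (illustrative: world_size=2, scatter_dim=0,
# gather_dim=-1, no explicit sizes): a [4, 8] input is split into two [2, 8]
# chunks, exchanged via dist.all_to_all, and the received chunks are then
# concatenated along gather_dim into a [2, 16] tensor on each rank.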
|
||||
|
||||
471
vllm_npu/distributed/cpu_offload_connector.py
Normal file
@@ -0,0 +1,471 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
MLAAttentionSpec)
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.distributed.cpu_offload_manager.metadata import (
|
||||
MetadataServer, MetadataServerProc, MLAConfig)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
gpu_block_ids: list[int]
|
||||
cpu_block_ids: list[int]
|
||||
num_scheduled_tokens: int
|
||||
num_computed_tokens: int
|
||||
num_gpu_computed_tokens: int
|
||||
num_cpu_computed_tokens: int
|
||||
|
||||
def update(self, other: "ReqMeta"):
|
||||
self.gpu_block_ids.extend(other.gpu_block_ids)
|
||||
self.cpu_block_ids.extend(other.cpu_block_ids)
|
||||
self.num_scheduled_tokens = other.num_scheduled_tokens
|
||||
self.num_computed_tokens = other.num_computed_tokens
|
||||
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
|
||||
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
requests: dict[str, ReqMeta]
|
||||
finished_req_ids: set[str]
|
||||
|
||||
|
||||
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
if not vllm_config.cache_config.enable_prefix_caching:
|
||||
self.connector_scheduler: Optional[
|
||||
CPUOffloadingConnectorScheduler] = None
|
||||
self.connector_worker: Optional[
|
||||
CPUOffloadingConnectorWorker] = None
|
||||
elif role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = CPUOffloadingConnectorScheduler(
|
||||
vllm_config)
|
||||
self.connector_worker = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
if self.connector_worker is not None:
|
||||
assert isinstance(connector_metadata,
|
||||
CPUOffloadingConnectorMetadata)
|
||||
self.connector_worker.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.clear_connector_metadata()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.start_load_kv()
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(), None
|
||||
|
||||
# ==============================
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.update_state_after_alloc(request)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.build_connector_meta(
|
||||
scheduler_output)
|
||||
return KVConnectorMetadata()
|
||||
|
||||
def request_finished(
|
||||
self, request: "Request",
|
||||
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
if self.connector_scheduler is not None:
|
||||
self.connector_scheduler.request_finished(request)
|
||||
return True, None
|
||||
|
||||
|
||||
class CPUOffloadingConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorScheduler")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
self.num_gpu_computed_tokens: dict[str, int] = {}
|
||||
self.num_cpu_computed_tokens: dict[str, int] = {}
|
||||
self.allocated_req_ids: set[str] = set()
|
||||
self.finished_req_ids: list[str] = []
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.zmq_rpc_client.call("post_init")
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"swap_in_threshold", 0)
|
||||
else:
|
||||
self.swap_in_threshold = 0
|
||||
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, ori_request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
|
||||
"get_matched_num_and_touch", request)
|
||||
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
||||
self.num_cpu_computed_tokens[
|
||||
request.request_id] = num_cpu_computed_tokens
|
||||
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
||||
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
||||
else:
|
||||
return 0, load_async
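# Worked example of the threshold decision above (all numbers are illustrative):
# the NPU prefix cache already covers 16 tokens (num_computed_tokens=16) and the
# CPU pool reports 48 matched tokens. With swap_in_threshold=16, 48 - 16 = 32 >= 16,
# so 32 extra matched tokens are reported and swapped in from CPU. With
# swap_in_threshold=64 the difference is below the threshold and 0 is returned,
# i.e. those tokens are recomputed on the NPU instead of being loaded.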
|
||||
|
||||
def update_state_after_alloc(self, request: "Request"):
|
||||
self.allocated_req_ids.add(request.request_id)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
num_tokens = {}
|
||||
# process scheduled_new_reqs
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
num_tokens[req_id] = (
|
||||
req.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
# process scheduled_cached_reqs
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
num_tokens[req_id] = (
|
||||
cached_reqs.num_computed_tokens[idx] +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
|
||||
self.allocated_req_ids -
|
||||
scheduler_output.num_scheduled_tokens.keys())
|
||||
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
|
||||
num_tokens,
|
||||
unallocated_req_ids)
|
||||
metadata = CPUOffloadingConnectorMetadata(
|
||||
requests={},
|
||||
finished_req_ids=set(self.finished_req_ids),
|
||||
)
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
gpu_block_ids = req.block_ids[0]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=req.num_computed_tokens,
|
||||
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
||||
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
|
||||
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
|
||||
self.num_gpu_computed_tokens.clear()
|
||||
self.num_cpu_computed_tokens.clear()
|
||||
self.allocated_req_ids.clear()
|
||||
self.finished_req_ids.clear()
|
||||
return metadata
|
||||
|
||||
def request_finished(self, ori_request: "Request"):
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
self.finished_req_ids.append(request.request_id)
|
||||
# Inform the metadata server to record the request and free it once sending finishes.
|
||||
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
|
||||
request)
|
||||
|
||||
|
||||
class CPUOffloadingConnectorWorker:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorWorker")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.pp_rank = get_pp_group().rank_in_group
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_rank = self.tp_group.rank_in_group
|
||||
self.tp_world_size = self.tp_group.world_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
self.load_stream = torch.npu.Stream()
|
||||
self.save_stream = torch.npu.Stream()
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.load_block_mapping: list[tuple[int, int]] = []
|
||||
self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
|
||||
self.save_output_queue: queue.Queue[str] = queue.Queue()
|
||||
self.save_thread = threading.Thread(target=self._save_listener)
|
||||
self.save_thread.start()
|
||||
self.done_sending_count: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
# Start the metadata server to init the cpu_kv_cache_manager and handle RPC requests.
# All DP ranks share the same metadata server, so only start the process on data_parallel_rank 0.
|
||||
if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
|
||||
config = VllmConfig()
|
||||
config.cache_config = vllm_config.cache_config
|
||||
config.parallel_config = vllm_config.parallel_config
|
||||
config.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.init_metadata_server(config)
|
||||
self._wait_for_metadata_process_start()
|
||||
|
||||
def init_metadata_server(self, vllm_config: VllmConfig):
|
||||
self.metadata_thread = threading.Thread(
|
||||
target=MetadataServerProc.run_metadata_server,
|
||||
args=(vllm_config, ),
|
||||
)
|
||||
self.metadata_thread.daemon = True
|
||||
self.metadata_thread.start()
|
||||
|
||||
def _wait_for_metadata_process_start(self):
|
||||
# TODO: wait for the metadata server to start; add an RPC to check if it is ready.
|
||||
while True:
|
||||
try:
|
||||
if self.zmq_rpc_client.call("ready"):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.info(f"wait for metadata server to start, error: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
|
||||
for req_id, req in connector_metadata.requests.items():
|
||||
if req_id in self.requests:
|
||||
self.requests[req_id].update(req)
|
||||
req = self.requests[req_id]
|
||||
else:
|
||||
self.requests[req_id] = req
|
||||
for i in range(req.num_gpu_computed_tokens // self.block_size,
|
||||
req.num_computed_tokens // self.block_size):
|
||||
self.load_block_mapping.append(
|
||||
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
|
||||
for req_id in connector_metadata.finished_req_ids:
|
||||
if req_id in self.requests:
|
||||
self.save_input_queue.put((req_id, self.requests[req_id]))
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
self.load_block_mapping.clear()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
|
||||
self.gpu_kv_caches = kv_caches
|
||||
model_config = self.vllm_config.model_config
|
||||
mla_config: Optional[MLAConfig] = None
|
||||
if model_config.use_mla:
|
||||
mla_config = MLAConfig(
|
||||
model_config.hf_text_config.kv_lora_rank,
|
||||
model_config.hf_text_config.qk_rope_head_dim)
|
||||
self.cpu_kv_caches = list(
|
||||
self.zmq_rpc_client.call(
|
||||
"init_cpu_kv_caches",
|
||||
self.pp_rank,
|
||||
self.tp_rank,
|
||||
get_kv_cache_spec(self.vllm_config),
|
||||
mla_config,
|
||||
).values())
|
||||
|
||||
def start_load_kv(self) -> None:
|
||||
self.current_layer = 0
|
||||
self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
|
||||
self.load_kv_layer(0)
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
# TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
|
||||
self.load_stream.synchronize()
|
||||
self.current_layer += 1
|
||||
self.load_kv_layer(self.current_layer)
|
||||
|
||||
def load_kv_layer(self, layer: int):
|
||||
if layer == len(self.gpu_kv_caches):
|
||||
return
|
||||
gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
|
||||
cpu_kv_caches = self.cpu_kv_caches[layer]
|
||||
with torch.npu.stream(self.load_stream):
|
||||
for cpu_block_id, gpu_block_id in self.load_block_mapping:
|
||||
for gpu_layer_part, cpu_layer_part in zip(
|
||||
gpu_kv_caches, cpu_kv_caches):
|
||||
gpu_layer_part[gpu_block_id].copy_(
|
||||
cpu_layer_part[cpu_block_id], non_blocking=True)
|
||||
|
||||
def get_finished(self) -> set[str]:
|
||||
done_sending: set[str] = set()
|
||||
while True:
|
||||
try:
|
||||
id = self.save_output_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
done_sending.add(id)
|
||||
for id in done_sending:
|
||||
del self.requests[id]
|
||||
if self.tp_world_size == 1:
|
||||
return done_sending
|
||||
if self.tp_rank == 0:
|
||||
for req_id in done_sending:
|
||||
self.done_sending_count[req_id] += 1
|
||||
other_ranks_finished_ids: list[str] = []
|
||||
for i in range(1, self.tp_world_size):
|
||||
other_ranks_finished_ids.extend(
|
||||
self.tp_group.recv_object(src=i))
|
||||
for req_id in other_ranks_finished_ids:
|
||||
self.done_sending_count[req_id] += 1
|
||||
all_done_sending: set[str] = set()
|
||||
for req_id in list(self.done_sending_count.keys()):
|
||||
if self.done_sending_count[req_id] == self.tp_world_size:
|
||||
del self.done_sending_count[req_id]
|
||||
all_done_sending.add(req_id)
|
||||
# release cpu_kv_cache after request sending finished
|
||||
# to avoid rpc blocking, use thread to call rpc asynchronously
|
||||
sending_finished_thread = threading.Thread(
|
||||
target=self._sending_finished, args=(all_done_sending, ))
|
||||
sending_finished_thread.daemon = True
|
||||
sending_finished_thread.start()
|
||||
|
||||
return all_done_sending
|
||||
else:
|
||||
self.tp_group.send_object(done_sending, dst=0)
|
||||
return done_sending
|
||||
|
||||
def _sending_finished(self, all_done_sending):
|
||||
for req_id in all_done_sending:
|
||||
logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
|
||||
self.zmq_rpc_client.call("cache_and_free_slots", req_id)
|
||||
|
||||
def _save_listener(self):
|
||||
save_block_mapping = []
|
||||
while True:
|
||||
req_id, req = self.save_input_queue.get()
|
||||
for i in range(
|
||||
req.num_cpu_computed_tokens // self.block_size,
|
||||
min((req.num_computed_tokens + req.num_scheduled_tokens) //
|
||||
self.block_size, len(req.cpu_block_ids))):
|
||||
save_block_mapping.append(
|
||||
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
|
||||
with torch.npu.stream(self.save_stream):
|
||||
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
|
||||
# non-MLA: kv_layer is list[tensor], typically means [k, v].
|
||||
if self.use_mla:
|
||||
start, step = self.tp_rank, self.tp_world_size
|
||||
else:
|
||||
start, step = 0, 1
|
||||
for i in range(start, len(save_block_mapping), step):
|
||||
gpu_block_id, cpu_block_id = save_block_mapping[i]
|
||||
for cpu_kv_caches, gpu_kv_caches in zip(
|
||||
self.cpu_kv_caches, self.gpu_kv_caches.values()):
|
||||
for cpu_layer_part, gpu_layer_part in zip(
|
||||
cpu_kv_caches, gpu_kv_caches):
|
||||
cpu_layer_part[cpu_block_id].copy_(
|
||||
gpu_layer_part[gpu_block_id],
|
||||
non_blocking=True)
|
||||
self.save_stream.synchronize()
|
||||
self.save_output_queue.put(req_id)
|
||||
save_block_mapping.clear()
|
||||
|
||||
|
||||
# Copied from vllm_npu/worker/model_runner_v1.py.
|
||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
forward_ctx = vllm_config.compilation_config.static_forward_context
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
ascend_config = get_ascend_config()
|
||||
use_sfa = ascend_config.use_sfa
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
if isinstance(attn_module, FusedMoE):
|
||||
continue
|
||||
assert isinstance(attn_module, Attention)
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if use_mla and not use_sfa:
|
||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
cache_dtype_str=vllm_config.cache_config.cache_dtype)
|
||||
else:
|
||||
# TODO(cmq): This is a hack to fix the DeepSeek KV cache spec when
|
||||
# using DSA. Fixing the spec in vLLM is the final solution.
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype)
|
||||
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
continue
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
return kv_cache_spec
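# A minimal, self-contained sketch (an assumption, not part of this diff) of the
# block-mapping copy pattern that load_kv_layer() and _save_listener() above rely on:
# whole KV blocks are copied between a CPU cache and a device cache according to
# (src_block_id, dst_block_id) pairs. Shapes and names here are illustrative; the real
# connector issues the copies on a dedicated torch.npu stream and synchronizes it
# before reporting completion.
import torch

block_size, num_heads, head_size = 16, 8, 128
cpu_cache = torch.randn(32, block_size, num_heads, head_size)   # 32 CPU blocks
npu_cache = torch.zeros(8, block_size, num_heads, head_size)    # 8 device blocks (CPU stand-in here)

load_block_mapping = [(3, 0), (7, 1)]  # (cpu_block_id, gpu_block_id)
for cpu_block_id, gpu_block_id in load_block_mapping:
    # non_blocking=True only overlaps with compute when the destination sits on an accelerator stream
    npu_cache[gpu_block_id].copy_(cpu_cache[cpu_block_id], non_blocking=True)
# the real connector then calls load_stream.synchronize() before the layer is consumed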
|
||||
202
vllm_npu/distributed/cpu_offload_manager/cpu_kv_cache_manager.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from vllm.utils import logger, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
PrefixCachingMetrics)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import \
|
||||
get_manager_for_kv_cache_spec
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class CPUCacheStats:
|
||||
|
||||
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.log_stats = log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
|
||||
self.time_sec = int(time.time())
|
||||
|
||||
def log(self):
|
||||
current_time_sec = int(time.time())
|
||||
# Log the prefix cache hit rate every 10 seconds.
|
||||
if current_time_sec - self.time_sec >= 10:
|
||||
self.time_sec = current_time_sec
|
||||
logger.info("CPU Prefix cache hit rate: %.1f%%",
|
||||
self.cpu_prefix_cache_metrics.hit_rate * 100)
|
||||
|
||||
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
||||
"""Get (and reset) the prefix cache stats.
|
||||
Returns:
|
||||
The current prefix caching stats, or None if logging is disabled.
|
||||
"""
|
||||
if not self.log_stats:
|
||||
return None
|
||||
stats = self.prefix_cache_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def update(self, num_tokens, num_computed_tokens):
|
||||
# Note the function is called by scheduler
|
||||
if self.log_stats and self.enable_prefix_caching:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
self.prefix_cache_stats.queries += num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
|
||||
def set_cache_stats(self, num_tokens, num_computed_tokens):
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.hits = num_computed_tokens
|
||||
self.prefix_cache_stats.queries = num_tokens
|
||||
self.prefix_cache_stats.requests = 1
|
||||
|
||||
|
||||
class CPUKVCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
num_cpu_blocks: int,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.use_eagle = use_eagle
|
||||
self.block_pool = BlockPool(self.num_cpu_blocks, True,
|
||||
enable_kv_cache_events)
|
||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
# Record kv block hashes, avoid redundant computation.
|
||||
self.req_to_block_hashes: defaultdict[
|
||||
str, list[BlockHash]] = defaultdict(list)
|
||||
# Record blocks touched in get_matched_num_and_touch().
|
||||
self.req_to_computed_blocks: defaultdict[
|
||||
str, list[KVCacheBlock]] = defaultdict(list)
|
||||
# Record the request that failed to allocate.
|
||||
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
||||
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
||||
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
|
||||
log_stats=True)
|
||||
# Record requests that will be freed after sending finishes
|
||||
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
||||
|
||||
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (request.sampling_params.prompt_logprobs is not None):
|
||||
return 0, False
|
||||
request_id = request.request_id
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request_id]
|
||||
if not block_hashes:
|
||||
block_hashes = request.block_hashes
|
||||
self.req_to_block_hashes[request_id] = block_hashes
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.single_type_manager.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
)
|
||||
num_computed_tokens = len(computed_blocks[0]) * self.block_size
|
||||
self.req_to_computed_blocks[request_id] = computed_blocks[0]
|
||||
# Touch these blocks so they stay protected in concurrent scenarios.
|
||||
self.block_pool.touch(computed_blocks)
|
||||
|
||||
# cpu prefix cache stats: update and log
|
||||
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
||||
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
|
||||
num_computed_tokens)
|
||||
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
|
||||
self.cpu_cache_stats.prefix_cache_stats)
|
||||
self.cpu_cache_stats.log()
|
||||
|
||||
return num_computed_tokens, False
|
||||
|
||||
def _release_ahead_touch(self, request_id: str):
|
||||
computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
if computed_blocks:
|
||||
self.single_type_manager.block_pool.free_blocks(
|
||||
reversed(computed_blocks))
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
|
||||
def allocate_slots(self, req_to_num_tokens: dict[str, int],
|
||||
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
||||
for request_id in unallocated_req_ids:
|
||||
self._free_slots(request_id)
|
||||
req_to_new_blocks = {}
|
||||
for request_id, num_tokens in req_to_num_tokens.items():
|
||||
if self.req_failed_to_allocate[request_id]:
|
||||
continue
|
||||
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
num_blocks_to_allocate = (
|
||||
self.single_type_manager.get_num_blocks_to_allocate(
|
||||
request_id=request_id,
|
||||
num_tokens=num_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
))
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
self._release_ahead_touch(request_id)
|
||||
self.req_failed_to_allocate[request_id] = True
|
||||
continue
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.single_type_manager.save_new_computed_blocks(
|
||||
request_id, new_computed_blocks)
|
||||
# Allocate new blocks but do not cache now.
|
||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||
request_id, num_tokens)
|
||||
self.req_to_num_tokens[request_id] = num_tokens
|
||||
# No need to release the ref_cnt added by touch(): the blocks are now officially owned by the request.
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
req_to_new_blocks[request_id] = [
|
||||
block.block_id for block in new_computed_blocks + new_blocks
|
||||
]
|
||||
return req_to_new_blocks
|
||||
|
||||
def record_request_cache_and_free_slots(self, request: Request):
|
||||
logger.debug(
|
||||
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
self.req_to_free[request.request_id] = request
|
||||
|
||||
def cache_and_free_slots(self, request_id: str):
|
||||
logger.debug(
|
||||
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
if request_id not in self.req_to_free:
|
||||
logger.error(
|
||||
f"request {request_id} not in req_to_free, maybe bug!")
|
||||
return
|
||||
request = self.req_to_free[request_id]
|
||||
if not self.req_failed_to_allocate[request_id]:
|
||||
self.single_type_manager.cache_blocks(
|
||||
request,
|
||||
self.req_to_num_tokens[request_id],
|
||||
)
|
||||
self._free_slots(request_id)
|
||||
logger.debug(
|
||||
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
||||
del self.req_to_free[request_id]
|
||||
|
||||
def _free_slots(self, request_id: str):
|
||||
# This function is designed to be reentrant.
|
||||
self._release_ahead_touch(request_id)
|
||||
self.single_type_manager.free(request_id)
|
||||
self.req_to_block_hashes.pop(request_id, None)
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
self.req_failed_to_allocate.pop(request_id, None)
|
||||
self.req_to_num_tokens.pop(request_id, None)
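# A minimal, self-contained sketch (an assumption, not code from this file) of the
# hit-rate bookkeeping that CPUCacheStats wraps: each request contributes `queries`
# (tokens looked up) and `hits` (tokens found in the CPU prefix cache), and the
# reported metric is accumulated hits over accumulated queries.
class SimplePrefixCacheMetrics:
    def __init__(self) -> None:
        self.queries = 0
        self.hits = 0

    def observe(self, num_tokens: int, num_computed_tokens: int) -> None:
        self.queries += num_tokens
        self.hits += num_computed_tokens

    @property
    def hit_rate(self) -> float:
        return self.hits / self.queries if self.queries else 0.0

metrics = SimplePrefixCacheMetrics()
metrics.observe(num_tokens=128, num_computed_tokens=96)  # 96 of 128 tokens were cached
metrics.observe(num_tokens=64, num_computed_tokens=0)
print(f"CPU prefix cache hit rate: {metrics.hit_rate * 100:.1f}%")  # -> 50.0%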
|
||||
269
vllm_npu/distributed/cpu_offload_manager/metadata.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.utils import get_dtype_size, logger, make_zmq_socket
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_npu.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
|
||||
CPUKVCacheManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAConfig:
|
||||
nope_dim: int
|
||||
rope_dim: int
|
||||
|
||||
|
||||
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
elif kv_transfer_config.kv_connector == "MultiConnector":
|
||||
ktcs = kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors")
|
||||
for ktc in ktcs:
|
||||
kv_transfer_config = KVTransferConfig(**ktc)
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
return None
|
||||
|
||||
|
||||
class MetadataServer:
|
||||
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
|
||||
DEFAULT_CPU_SWAP_SPACE_GB = 800
|
||||
|
||||
class ZMQRPCClient:
|
||||
|
||||
def __init__(self, identity=f"worker-{os.getpid()}"):
|
||||
logger.info(f"metadata client for worker {identity} started")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.DEALER, # type: ignore
|
||||
bind=False,
|
||||
identity=identity.encode(),
|
||||
linger=0)
|
||||
|
||||
def call(self, func_name: str, *args, **kwargs) -> Any:
|
||||
request = (func_name, args, kwargs)
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(request))
|
||||
_ = self.socket.recv()
|
||||
response = pickle.loads(self.socket.recv())
|
||||
result, error = response
|
||||
if error:
|
||||
logger.exception(f"call metadata sever error: {error}")
|
||||
raise error
|
||||
if func_name == "init_cpu_kv_caches":
|
||||
(memory_dict, layer_size, layer_dtype, mla_config) = result
|
||||
# shared_memory_dict is kept on self so the segments can be closed later
|
||||
self.shared_memory_dict = memory_dict
|
||||
result = {}
|
||||
for key, shm in memory_dict.items():
|
||||
tensor = torch.frombuffer(
|
||||
shm.buf, dtype=layer_dtype).reshape(layer_size)
|
||||
if mla_config is not None:
|
||||
tensor = tensor.split(
|
||||
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
|
||||
result[key] = tensor
|
||||
return result
|
||||
|
||||
def __del__(self):
|
||||
# will be finalized by outer process
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
if hasattr(self, 'shared_memory_dict'):
|
||||
for shm in self.shared_memory_dict.values():
|
||||
shm.close()
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||
kv_transfer_config = get_cpu_offload_connector(vllm_config)
|
||||
assert kv_transfer_config is not None
|
||||
available_memory_gb = kv_transfer_config.get_from_extra_config(
|
||||
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
|
||||
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
|
||||
logger.info(f"cpu swap space: {self.available_memory} bytes")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.ROUTER, # type: ignore
|
||||
bind=True,
|
||||
linger=0)
|
||||
self.functions: dict[str, Callable] = {
|
||||
"init_cpu_kv_caches": self.init_cpu_kv_caches,
|
||||
"post_init": self.post_init,
|
||||
"ready": self.ready,
|
||||
}
|
||||
self.shared_memory = {} # type: ignore
|
||||
self.num_cpu_blocks = -1
|
||||
|
||||
@staticmethod
|
||||
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
|
||||
try:
|
||||
existing_shm = SharedMemory(name=name, create=False)
|
||||
existing_shm.close()
|
||||
existing_shm.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return SharedMemory(name=name, create=True, size=size)
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
def init_cpu_kv_caches(
|
||||
self,
|
||||
pp_rank: int,
|
||||
tp_rank: int,
|
||||
kv_cache_specs: dict[str, AttentionSpec],
|
||||
mla_config: MLAConfig,
|
||||
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
|
||||
MLAConfig]:
|
||||
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
|
||||
# follow the assumption that each layer has the same spec
|
||||
layer = next(iter(kv_cache_specs.values()))
|
||||
assert all([
|
||||
layer.page_size_bytes == spec.page_size_bytes
|
||||
for spec in kv_cache_specs.values()
|
||||
])
|
||||
# mla shares the same kv cache among different tp
|
||||
if layer.use_mla:
|
||||
tp_rank = 0
|
||||
if (pp_rank, tp_rank) in self.shared_memory:
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
|
||||
available_memory = self.available_memory
|
||||
shared_memory_dict = {}
|
||||
if layer.use_mla:
|
||||
available_memory //= self.pipeline_parallel_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
else:
|
||||
available_memory //= self.world_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
|
||||
for layer_name in kv_cache_specs.keys():
|
||||
# only this format can be shared across processes via ZeroMQ + pickle
|
||||
shared_memory_dict[
|
||||
layer_name] = MetadataServer._safe_create_shared_memory(
|
||||
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
||||
if layer.use_mla:
|
||||
assert mla_config is not None
|
||||
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, mla_config)
|
||||
else:
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, None)
|
||||
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
|
||||
self.num_cpu_blocks = num_blocks
|
||||
self.layer = layer
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
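# Standalone sketch (an assumption, not code from this file) of the zero-copy handoff
# that init_cpu_kv_caches() and ZMQRPCClient.call() implement: the server allocates a
# named SharedMemory segment and the worker maps it into a torch tensor with
# torch.frombuffer, so both processes see the same CPU KV cache bytes. The segment
# name and shape below are hypothetical.
import math
import torch
from multiprocessing.shared_memory import SharedMemory

layer_size = (4, 16, 8, 128)            # (num_blocks, block_size, num_kv_heads, head_size)
nbytes = math.prod(layer_size) * 2      # float16 is 2 bytes per element

shm = SharedMemory(name="cpu_kv_cache_demo", create=True, size=nbytes)
cache = torch.frombuffer(shm.buf, dtype=torch.float16).reshape(layer_size)
cache.zero_()      # writes are visible to every process that maps this segment
del cache          # drop the buffer view before closing the mapping
shm.close()
shm.unlink()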
|
||||
|
||||
def post_init(self):
|
||||
# different processes in data parallel may call this multiple times
|
||||
if hasattr(self, 'cpu_block_manager'):
|
||||
return
|
||||
# do shared_memory() at least once
|
||||
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
|
||||
assert self.num_cpu_blocks >= 0
|
||||
self.cpu_block_manager = CPUKVCacheManager(self.layer,
|
||||
self.num_cpu_blocks)
|
||||
self.functions.update({
|
||||
"get_matched_num_and_touch":
|
||||
self.cpu_block_manager.get_matched_num_and_touch,
|
||||
"allocate_slots":
|
||||
self.cpu_block_manager.allocate_slots,
|
||||
"record_request_cache_and_free_slots":
|
||||
self.cpu_block_manager.record_request_cache_and_free_slots,
|
||||
"cache_and_free_slots":
|
||||
self.cpu_block_manager.cache_and_free_slots,
|
||||
})
|
||||
|
||||
def serve_step(self):
|
||||
client_id = self.socket.recv()
|
||||
_ = self.socket.recv()
|
||||
raw_msg = self.socket.recv()
|
||||
try:
|
||||
func_name, args, kwargs = pickle.loads(raw_msg)
|
||||
except Exception as e:
|
||||
response = (None, Exception(f"Invalid request: {str(e)}"))
|
||||
else:
|
||||
if func_name in self.functions:
|
||||
try:
|
||||
result = self.functions[func_name](*args, **kwargs)
|
||||
response = (result, None) # type: ignore
|
||||
except Exception as e:
|
||||
logger.exception(f"metadata execute error: {e}")
|
||||
response = (None, e) # type: ignore
|
||||
else:
|
||||
response = (None, NameError(f"Function {func_name} not found"))
|
||||
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(response))
|
||||
|
||||
def shutdown(self):
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
|
||||
"ipc://", "")
|
||||
if os.path.exists(socket_path):
|
||||
os.remove(socket_path)
|
||||
for cached in self.shared_memory.values():
|
||||
for shm in cached[0].values():
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
|
||||
class MetadataServerProc:
|
||||
|
||||
@staticmethod
|
||||
def run_metadata_server(vllm_config: VllmConfig):
|
||||
if (not vllm_config.cache_config.enable_prefix_caching
|
||||
or get_cpu_offload_connector(vllm_config) is None):
|
||||
return
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
def _signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the worker
|
||||
# signal.signal(signal.SIGTERM, _signal_handler)
|
||||
# signal.signal(signal.SIGINT, _signal_handler)
|
||||
metadata_server: Optional[MetadataServer] = None
|
||||
try:
|
||||
metadata_server = MetadataServer(vllm_config)
|
||||
logger.info("Metadata server started.")
|
||||
while True:
|
||||
metadata_server.serve_step()
|
||||
except SystemExit:
|
||||
logger.info("Metadata server exiting.")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Metadata server error: {e}.")
|
||||
raise e
|
||||
finally:
|
||||
if metadata_server is not None:
|
||||
metadata_server.shutdown()
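# Standalone sketch (an assumption, not code from this diff) of the ROUTER/DEALER RPC
# pattern used by MetadataServer.serve_step() and ZMQRPCClient.call(): the client sends
# an empty delimiter frame plus a pickled (func_name, args, kwargs) tuple, and the
# server replies with a pickled (result, error) pair routed back by the client identity.
# The inproc address and the "add" function are hypothetical stand-ins.
import pickle
import threading
import zmq

ADDR = "inproc://metadata-demo"   # the real code binds an ipc:// path instead
ctx = zmq.Context.instance()

server = ctx.socket(zmq.ROUTER)
server.bind(ADDR)

def serve_one():
    identity, _, raw = server.recv_multipart()
    func_name, args, kwargs = pickle.loads(raw)
    response = (sum(args), None) if func_name == "add" else (None, NameError(func_name))
    server.send_multipart((identity, b"", pickle.dumps(response)))

t = threading.Thread(target=serve_one, daemon=True)
t.start()

client = ctx.socket(zmq.DEALER)
client.setsockopt(zmq.IDENTITY, b"worker-demo")
client.connect(ADDR)
client.send(b"", zmq.SNDMORE)                      # empty delimiter, as in ZMQRPCClient.call
client.send(pickle.dumps(("add", (1, 2, 3), {})))
_ = client.recv()                                  # delimiter frame
print(pickle.loads(client.recv()))                 # -> (6, None)
t.join()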
|
||||
165
vllm_npu/distributed/device_communicators/pyhccl.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_npu.distributed.device_communicators.pyhccl_wrapper import (
|
||||
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum, hcclUniqueId)
|
||||
from vllm_npu.utils import current_stream
|
||||
|
||||
|
||||
class PyHcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyHcclCommunicator to. If None,
|
||||
it will be bind to f"npu:{local_rank}".
|
||||
library_path: the path to the HCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bound to a unique device.
|
||||
"""
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.HCCL, (
|
||||
"PyHcclCommunicator should be attached to a non-HCCL group.")
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
try:
|
||||
self.hccl = HCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing HCCL library
|
||||
# e.g. in a non-NPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("vLLM is using pyhccl")
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"npu:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
if self.rank == 0:
|
||||
# get the unique id from HCCL
|
||||
with torch.npu.device(device):
|
||||
self.unique_id = self.hccl.hcclGetUniqueId()
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = hcclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
|
||||
# hccl communicator and stream will use this device
|
||||
# `torch.npu.device` is a context manager that changes the
|
||||
# current npu device to the specified one
|
||||
with torch.npu.device(device):
|
||||
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# hccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
hcclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
aclrtStream_t(stream.npu_stream))
|
||||
return out_tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
# the source and non-source ranks pass the same local buffer to HcclBroadcast
buffer = buffer_type(tensor.data_ptr())
|
||||
self.hccl.hcclBroadcast(buffer, tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, aclrtStream_t(stream.npu_stream))
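# Hypothetical usage sketch (assumes an Ascend environment with torch_npu and the HCCL
# library installed, and torch.distributed already initialized with a non-HCCL group,
# e.g. gloo). It is not runnable elsewhere and only illustrates how PyHcclCommunicator
# is meant to be driven: one communicator per rank, bound to that rank's NPU.
import torch
import torch.distributed as dist
from vllm_npu.distributed.device_communicators.pyhccl import PyHcclCommunicator

def demo_all_reduce(group: dist.ProcessGroup, local_rank: int) -> torch.Tensor:
    comm = PyHcclCommunicator(group=group, device=local_rank)  # binds to f"npu:{local_rank}"
    x = torch.ones(4, device=f"npu:{local_rank}")
    out = comm.all_reduce(x)            # returns a new tensor; None when the comm is disabled
    torch.npu.current_stream().synchronize()
    return out                          # each element equals the group's world size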
|
||||
253
vllm_npu/distributed/device_communicators/pyhccl_wrapper.py
Normal file
@@ -0,0 +1,253 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_npu.utils import find_hccl_library
|
||||
|
||||
# export types and functions from hccl to Python ===
|
||||
# for the original hccl definition, please check
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
|
||||
|
||||
hcclResult_t = ctypes.c_int
|
||||
hcclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class hcclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 4108)]
|
||||
|
||||
|
||||
aclrtStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
hcclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class hcclDataTypeEnum:
|
||||
hcclInt8 = 0
|
||||
hcclInt16 = 1
|
||||
hcclInt32 = 2
|
||||
hcclFloat16 = 3
|
||||
hcclFloat32 = 4
|
||||
hcclInt64 = 5
|
||||
hcclUint64 = 6
|
||||
hcclUint8 = 7
|
||||
hcclUint16 = 8
|
||||
hcclUint32 = 9
|
||||
hcclFloat64 = 10
|
||||
hcclBfloat16 = 11
|
||||
hcclInt128 = 12
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.hcclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.hcclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.hcclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.hcclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.hcclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.hcclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.hcclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.hcclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
hcclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class hcclRedOpTypeEnum:
|
||||
hcclSum = 0
|
||||
hcclProd = 1
|
||||
hcclMax = 2
|
||||
hcclMin = 3
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.hcclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.hcclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.hcclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.hcclMin
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class HCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* HcclGetErrorString(HcclResult code);
|
||||
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
|
||||
|
||||
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
|
||||
Function("HcclGetRootInfo", hcclResult_t,
|
||||
[ctypes.POINTER(hcclUniqueId)]),
|
||||
|
||||
# HcclResult HcclCommInitRootInfo(
|
||||
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
|
||||
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
|
||||
Function("HcclCommInitRootInfo", hcclResult_t, [
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclUniqueId),
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclComm_t),
|
||||
]),
|
||||
|
||||
# HcclResult HcclAllReduce(
|
||||
# void *sendBuf, void *recvBuf, uint64_t count,
|
||||
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
|
||||
# aclrtStream stream);
|
||||
Function("HcclAllReduce", hcclResult_t, [
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
hcclRedOp_t,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
# HcclResult HcclBroadcast(
|
||||
# void *buf, uint64_t count,
|
||||
# HcclDataType dataType, uint32_t root,
|
||||
# HcclComm comm, aclrtStream stream);
|
||||
Function("HcclBroadcast", hcclResult_t, [
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
ctypes.c_int,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
# HcclResult HcclCommDestroy(HcclComm comm);
|
||||
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding function dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_hccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
HCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = HCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load HCCL library from %s. "
|
||||
"It is expected if you are not running on Ascend NPUs."
|
||||
"Otherwise, the hccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable HCCL_SO_PATH"
|
||||
" to point to the correct hccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
for func in HCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def hcclGetErrorString(self, result: hcclResult_t) -> str:
|
||||
return self._funcs["HcclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def HCCL_CHECK(self, result: hcclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.hcclGetErrorString(result)
|
||||
raise RuntimeError(f"HCCL error: {error_str}")
|
||||
|
||||
def hcclGetUniqueId(self) -> hcclUniqueId:
|
||||
unique_id = hcclUniqueId()
|
||||
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
|
||||
rank: int) -> hcclComm_t:
|
||||
comm = hcclComm_t()
|
||||
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
|
||||
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
|
||||
return comm
|
||||
|
||||
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
# `datatype` actually should be `hcclDataType_t`
|
||||
# and `op` should be `hcclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
|
||||
root: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
|
||||
root, comm, stream))
|
||||
|
||||
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HCCLLibrary",
|
||||
"hcclDataTypeEnum",
|
||||
"hcclRedOpTypeEnum",
|
||||
"hcclUniqueId",
|
||||
"hcclComm_t",
|
||||
"aclrtStream_t",
|
||||
"buffer_type",
|
||||
]
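# A small, self-contained illustration (an assumption, not part of this diff) of the same
# ctypes wrapping pattern HCCLLibrary uses: declare each exported function's restype and
# argtypes up front from a table, then call through the resulting dictionary. libm's `pow`
# stands in for the HCCL symbols; find_library may return None on unusual platforms.
import ctypes
import ctypes.util
from dataclasses import dataclass
from typing import Any, List

@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]

exported = [Function("pow", ctypes.c_double, [ctypes.c_double, ctypes.c_double])]

libm = ctypes.CDLL(ctypes.util.find_library("m"))
funcs = {}
for func in exported:
    f = getattr(libm, func.name)
    f.restype = func.restype      # without this, ctypes would assume a c_int return value
    f.argtypes = func.argtypes    # lets ctypes convert Python floats to c_double arguments
    funcs[func.name] = f

assert funcs["pow"](2.0, 10.0) == 1024.0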
|
||||
994
vllm_npu/distributed/llmdatadist_c_mgr_connector.py
Normal file
@@ -0,0 +1,994 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
|
||||
LLMException, LLMRole)
|
||||
from vllm import envs
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import get_ip, logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
import vllm_npu.envs as envs_ascend
|
||||
from vllm_npu.distributed.utils import get_transfer_timeout_value
|
||||
from vllm_npu.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.float16: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.bfloat16: llm_datadist.DataType.DT_BF16,
|
||||
torch.float: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.float32: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.int8: llm_datadist.DataType.DT_INT8,
|
||||
torch.int64: llm_datadist.DataType.DT_INT64,
|
||||
torch.int32: llm_datadist.DataType.DT_INT32
|
||||
}
|
||||
|
||||
|
||||
class LLMDataDistCMgrEvent(Enum):
|
||||
ReqForMetadata = 0
|
||||
ReqForFinished = 1
|
||||
|
||||
|
||||
class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
|
||||
super_pod_id: str
|
||||
server_id: str
|
||||
device_id: str
|
||||
device_ip: str
|
||||
super_device_id: str
|
||||
cluster_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: str
|
||||
engine_id: str
|
||||
remote_tp_size: str
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self):
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
|
||||
def add_new_req(self, request_id: str, local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any]):
|
||||
self.requests[request_id] = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_tp_size=kv_transfer_params["remote_tp_size"],
|
||||
)
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler: Optional[
|
||||
LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
|
||||
vllm_config, self.engine_id)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(
|
||||
self,
|
||||
kv_caches: dict[
|
||||
str, # type: ignore[override]
|
||||
Tuple[torch.Tensor]]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(finished_req_ids)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata,
|
||||
LLMDataDistCMgrConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager."""
|
||||
pass
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata, **kwargs) -> None:
|
||||
"""LLMDataDistCMgrConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
"""LLMDataDistCMgrConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorScheduler():
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id = engine_id
|
||||
self.local_ip = get_ip()
|
||||
# Cannot retrieve the parallel config here since it is not initialized yet.
|
||||
self.local_dp_rank = None
|
||||
self.tp_size = None
|
||||
if vllm_config.parallel_config.data_parallel_external_lb:
|
||||
dp_rank_local = vllm_config.parallel_config.data_parallel_rank
|
||||
else:
|
||||
dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
self.port = dp_rank_local * tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT
|
||||
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[str, float] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
"""
|
||||
For remote prefill, pull all prompt blocks from remote
|
||||
asynchronously relative to engine execution.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
Returns:
|
||||
* the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
* true if the external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps).
|
||||
"""
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}"
|
||||
)
|
||||
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
assert num_computed_tokens % self.block_size == 0
|
||||
# Note: We use the full token count as transmit data here.
|
||||
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
|
||||
return count, count > 0
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
|
||||
num_external_tokens: int):
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}"
|
||||
)
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
if params.get("remote_block_ids"):
|
||||
if all(p in params for p in ("remote_engine_id", "remote_host",
|
||||
"remote_port", "remote_tp_size")):
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request, blocks.get_unhashed_block_ids())
|
||||
else:
|
||||
logger.warning("" \
|
||||
f"Invalid KVTransferParams {params}, This request will be discard")
|
||||
else:
|
||||
assert num_external_tokens == 0
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
meta = LLMDataDistCMgrConnectorMetadata()
|
||||
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
meta.add_new_req(request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params)
|
||||
|
||||
meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)
|
||||
|
||||
# Clear the list once workers start the transfers
|
||||
self._reqs_need_recv.clear()
|
||||
self._reqs_need_send.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"LLMDataDistCMgrConnector request_finished, request_status=%s, "
|
||||
"kv_transfer_params=%s", request.status, params)
|
||||
|
||||
if (params is None or not params.get("do_remote_decode")
|
||||
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
||||
return False, None
|
||||
|
||||
# note: NIXL transfers full blocks only, but there is no strong reason to do that here, so
|
||||
# we simply transfer whatever data was computed on the prefill node.
|
||||
# note: there might be issues with this; check here first if results look unexpected.
|
||||
computed_block_ids = block_ids
|
||||
delay_free_blocks = len(computed_block_ids) > 0
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(computed_block_ids), request.request_id)
|
||||
# Prefill request on remote. It will be read from D upon completion
|
||||
self._reqs_need_send[request.request_id] = time.perf_counter(
|
||||
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_block_ids=computed_block_ids,
|
||||
remote_engine_id=self.engine_id,
|
||||
remote_host=self.local_ip,
|
||||
remote_port=self.port,
|
||||
remote_tp_size=str(
|
||||
self.vllm_config.parallel_config.tensor_parallel_size),
|
||||
)
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorWorker():
|
||||
"""
|
||||
Implementation of Worker side methods
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
|
||||
# we assume the local node only contains dp and tp, and tp will not communicate inter-node.
|
||||
# for any scenario beyond this scope, the functionality of this connector is not guaranteed.
|
||||
self.local_rank_on_node = get_world_group().rank % (
|
||||
vllm_config.parallel_config.data_parallel_size_local *
|
||||
vllm_config.parallel_config.tensor_parallel_size)
|
||||
self.local_rank = get_world_group().local_rank
|
||||
if vllm_config.parallel_config.data_parallel_external_lb:
|
||||
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
else:
|
||||
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
|
||||
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.rank = get_world_group().rank
|
||||
self.local_ip = get_ip()
|
||||
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
|
||||
self.local_agent_metadata: Optional[
|
||||
LLMDataDistCMgrAgentMetadata] = None
|
||||
self.vllm_config = vllm_config
|
||||
self.executor = ThreadPoolExecutor(1)
|
||||
self.thread_lock = threading.Lock()
|
||||
|
||||
self.llm_datadist_role = None
|
||||
self.llm_datadist_remote_role = None
|
||||
if self.kv_transfer_config.kv_role == "kv_producer":
|
||||
self.llm_datadist_role = LLMRole.PROMPT
|
||||
self.llm_datadist_remote_role = LLMRole.DECODER
|
||||
elif self.kv_transfer_config.kv_role == "kv_consumer":
|
||||
self.llm_datadist_role = LLMRole.DECODER
|
||||
self.llm_datadist_remote_role = LLMRole.PROMPT
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}"
|
||||
)
|
||||
|
||||
# linked_cluster records the clusters that have already built a connection; its format is {"cluster_id": "comm_name"}
|
||||
self.linked_cluster: dict[Any, Any] = {}
|
||||
self.prefill_device_list: list[tuple[int, int]] = []
|
||||
self.decode_device_list: list[tuple[int, int]] = []
|
||||
global_rank_table = self.read_offline_rank_table()
|
||||
self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
|
||||
self.llm_datadist = LLMDataDist(self.llm_datadist_role,
|
||||
self.local_agent_metadata.cluster_id)
|
||||
self.init_llm_datadist()
|
||||
self.finished_reqs: set[str] = set()
|
||||
self.soc_info = get_ascend_soc_version()
|
||||
# Set hccl deterministic for model execute
|
||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||
self.done_receiving_counts: defaultdict[str,
|
||||
set[int]] = defaultdict(set)
|
||||
self.reqs_to_send: dict[str, float] = {}
|
||||
|
||||
def listen_for_agent_metadata_req(self, event: threading.Event):
|
||||
assert self.local_agent_metadata is not None
|
||||
port = envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
|
||||
url = f"tcp://{envs_ascend.vllm_npu_LLMDD_RPC_IP}:{port}"
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_decoder = msgspec.msgpack.Decoder()
|
||||
msg_to_send = msg_encoder.encode(self.local_agent_metadata)
|
||||
logger.debug(f"Start to listen to address: {url}")
|
||||
logger.debug(
|
||||
f"The local agent metadata have {len(msg_to_send)} bytes here")
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers"
|
||||
)
|
||||
with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined]
|
||||
event.set()
|
||||
while True:
|
||||
identity, _, msg = sock.recv_multipart()
|
||||
event_msg, decode_msg = msg_decoder.decode(msg)
|
||||
event_msg = LLMDataDistCMgrEvent(event_msg)
|
||||
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
|
||||
if "cluster_id" in decode_msg:
|
||||
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
|
||||
)
|
||||
sock.send_multipart((identity, b"", msg_to_send))
|
||||
self.add_remote_agent(decode_msg)
|
||||
else:
|
||||
logger.warning(
|
||||
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
|
||||
)
|
||||
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
||||
finished_req_id = decode_msg[0]
|
||||
with self.thread_lock:
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||
)
|
||||
if finished_req_id in self.reqs_to_send:
|
||||
self.finished_reqs.add(finished_req_id)
|
||||
del self.reqs_to_send[finished_req_id]
|
||||
sock.send_multipart(
|
||||
(identity, b"", b"receiving decode finished"))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
||||
)
|
||||
|
||||
def init_llm_datadist(self):
|
||||
assert self.local_agent_metadata is not None
|
||||
llm_config = LLMConfig()
|
||||
llm_config.device_id = self.local_rank
|
||||
llm_config.sync_kv_timeout = get_transfer_timeout_value()
|
||||
llm_config.enable_switch_role = True
|
||||
llm_config.enable_cache_manager = True
|
||||
llm_config.enable_remote_cache_accessible = True
|
||||
llm_config_options = llm_config.generate_options()
|
||||
self.llm_datadist.init(llm_config_options)
|
||||
self.cache_manager = self.llm_datadist.cache_manager
|
||||
logger.info(
|
||||
f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
|
||||
)
|
||||
|
||||
def read_offline_rank_table(self):
|
||||
assert (
|
||||
envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
|
||||
rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
with open(rank_table_path, "r", encoding="utf-8") as f:
|
||||
global_rank_table = json.load(f)
|
||||
decode_device_list = global_rank_table["decode_device_list"]
|
||||
for decode_device in decode_device_list:
|
||||
server_id = decode_device["server_id"]
|
||||
device_id = decode_device["device_id"]
|
||||
self.decode_device_list.append((server_id, device_id))
|
||||
prefill_device_list = global_rank_table["prefill_device_list"]
|
||||
for prefill_device in prefill_device_list:
|
||||
server_id = prefill_device["server_id"]
|
||||
device_id = prefill_device["device_id"]
|
||||
self.prefill_device_list.append((server_id, device_id))
|
||||
|
||||
# global_rank_table = json.dumps(global_rank_table)
|
||||
return global_rank_table
|
||||
|
||||
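# Illustrative sketch (not part of the original change): the offline rank table
# consumed above is expected to look roughly like the following JSON; every value
# here is a placeholder, not taken from a real deployment.
#
# {
#     "prefill_device_list": [
#         {"server_id": "10.0.0.1", "device_id": "0", "device_ip": "192.168.100.1", "cluster_id": "1"}
#     ],
#     "decode_device_list": [
#         {"server_id": "10.0.0.2", "device_id": "0", "device_ip": "192.168.100.2", "cluster_id": "2"}
#     ]
# }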
@staticmethod
|
||||
def _get_visible_devices() -> Callable[[str], bool]:
|
||||
"""
|
||||
Return a predicate that checks whether the given device ID is visible,
i.e. ASCEND_RT_VISIBLE_DEVICES is either unset or contains the device_id.
|
||||
"""
|
||||
visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
|
||||
if not visible_devices:
|
||||
return lambda device_id: True
|
||||
visible_device_list = visible_devices.split(",")
|
||||
return lambda device_id: device_id in visible_device_list
|
||||
|
||||
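# Usage sketch (illustrative): with ASCEND_RT_VISIBLE_DEVICES="0,1",
#     is_visible = LLMDataDistCMgrConnectorWorker._get_visible_devices()
#     is_visible("0")  -> True
#     is_visible("7")  -> False
# When the variable is unset, every device_id is reported as visible.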
def read_agent_metadata(self, global_rank_table):
|
||||
device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
|
||||
devices_type_list = []
|
||||
agent_metadata = None
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
devices_type_list.append("prefill_device_list")
|
||||
elif self.llm_datadist_role == LLMRole.DECODER:
|
||||
devices_type_list.append("decode_device_list")
|
||||
else:
|
||||
devices_type_list.append("prefill_device_list")
|
||||
devices_type_list.append("decode_device_list")
|
||||
for device_type in devices_type_list:
|
||||
device_list = global_rank_table[device_type]
|
||||
device_list = [
|
||||
d for d in device_list if d.get("server_id") == self.local_ip
|
||||
and device_filter(d.get("device_id", ""))
|
||||
]
|
||||
if len(device_list) <= self.tp_rank:
|
||||
continue
|
||||
device_info = device_list[self.tp_rank]
|
||||
super_pod_id_ = device_info.get("super_pod_id", None)
|
||||
server_id_ = device_info["server_id"]
|
||||
device_id_ = device_info["device_id"]
|
||||
device_ip_ = device_info["device_ip"]
|
||||
super_device_id_ = device_info.get("super_device_id", None)
|
||||
cluster_id_ = int(device_info["cluster_id"])
|
||||
agent_metadata = LLMDataDistCMgrAgentMetadata(
|
||||
super_pod_id=super_pod_id_,
|
||||
server_id=server_id_,
|
||||
device_id=device_id_,
|
||||
device_ip=device_ip_,
|
||||
super_device_id=super_device_id_,
|
||||
cluster_id=cluster_id_,
|
||||
)
|
||||
assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
|
||||
return agent_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
assert len(first_kv_cache_tuple) > 1
|
||||
assert self.local_agent_metadata is not None
|
||||
kv_cache_dtype = first_kv_cache.dtype
|
||||
self.use_mla: bool = (first_kv_cache_tuple[0].size(-1)
                      != first_kv_cache_tuple[1].size(-1)
                      and len(first_kv_cache_tuple) == 2)
|
||||
self.use_sparse: bool = len(first_kv_cache_tuple) == 3
|
||||
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
|
||||
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
|
||||
# MHA case. [2 (k and v), num_blocks, ...]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
|
||||
self.block_len = math.prod(block_shape)
|
||||
self.cache_addr: list[int] = []
|
||||
alignment = 2 * 1024 * 1024
|
||||
if self.use_mla:
|
||||
cache_k_normed_addr_list = []
|
||||
cache_k_pe_addr_list = []
|
||||
k_normed = None
|
||||
k_pe = None
|
||||
for cache_or_caches in kv_caches.values():
|
||||
assert len(cache_or_caches) > 1
|
||||
k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
|
||||
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)
|
||||
|
||||
cache_desc_k_normed = CacheDesc(
|
||||
len(self.cache_addr[0]), [*k_normed.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_pe = CacheDesc(
|
||||
len(self.cache_addr[1]), [*k_pe.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=0)
|
||||
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=1)
|
||||
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
|
||||
self.cache_key = (cache_key_k_normed, cache_key_k_pe)
|
||||
try:
|
||||
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||
self.cache = (cache_k_normed, cache_k_pe)
|
||||
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
elif self.use_sparse:
|
||||
cache_k_normed_addr_list = []
|
||||
cache_k_pe_addr_list = []
|
||||
cache_k_idx_addr_list = []
|
||||
k_normed = None
|
||||
k_pe = None
|
||||
k_idx = None
|
||||
for cache_or_caches in kv_caches.values():
|
||||
assert len(cache_or_caches) > 1
|
||||
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[1], cache_or_caches[2]
|
||||
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||
cache_k_idx_addr_list.append(k_idx.data_ptr())
|
||||
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
|
||||
cache_k_idx_addr_list)
|
||||
|
||||
cache_desc_k_normed = CacheDesc(
|
||||
len(self.cache_addr[0]), [*k_normed.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_pe = CacheDesc(
|
||||
len(self.cache_addr[1]), [*k_pe.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_idx = CacheDesc(
|
||||
len(self.cache_addr[2]), [*k_idx.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=0)
|
||||
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=1)
|
||||
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=2)
|
||||
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
|
||||
cache_desc_k_idx)
|
||||
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
|
||||
cache_key_k_idx)
|
||||
try:
|
||||
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||
cache_k_idx = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
|
||||
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
|
||||
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
else:
|
||||
for cache_or_caches in kv_caches.values():
|
||||
for cache in cache_or_caches:
|
||||
base_addr = cache.data_ptr()
|
||||
assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
||||
self.cache_addr.append(base_addr)
|
||||
# register paged kv cache into the llm_cache manager
|
||||
self.cache_desc = CacheDesc(
|
||||
len(self.cache_addr), [*cache.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
self.cache_key = BlocksCacheKey(
|
||||
cluster_id=int(self.local_agent_metadata.cluster_id))
|
||||
logger.info(
|
||||
f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
|
||||
)
|
||||
try:
|
||||
self.cache = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc, self.cache_addr, self.cache_key)
|
||||
logger.info(
|
||||
"LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
self.ready_event = threading.Event()
|
||||
self.metadata_agent_listener_t = threading.Thread(
|
||||
target=self.listen_for_agent_metadata_req,
|
||||
args=(self.ready_event, ),
|
||||
daemon=True,
|
||||
name="metadata_agent_listener")
|
||||
self.metadata_agent_listener_t.start()
|
||||
self.ready_event.wait()
|
||||
|
||||
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
|
||||
futures = []
|
||||
for req_id, meta in metadata.requests.items():
|
||||
logger.debug(f"Start to transmit {req_id}")
|
||||
future = self.executor.submit(
|
||||
self._read_blocks,
|
||||
local_block_ids=meta.local_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
remote_ip=meta.remote_host,
|
||||
remote_port=int(meta.remote_port),
|
||||
remote_engine_id=meta.engine_id,
|
||||
request_id=req_id,
|
||||
remote_tp_size=meta.remote_tp_size,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
def handle_exception(future):
|
||||
if future.exception():
|
||||
logger.error(f"KV transfer task failed: {future.exception()}")
|
||||
|
||||
for future in futures:
|
||||
future.add_done_callback(handle_exception)
|
||||
self.reqs_to_send.update(metadata.reqs_to_send)
|
||||
|
||||
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
|
||||
assert self.local_agent_metadata is not None
|
||||
remote_cluster_id = metadata.cluster_id
|
||||
if remote_cluster_id in self.linked_cluster:
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
|
||||
)
|
||||
return remote_cluster_id
|
||||
remote_super_pod_id = metadata.super_pod_id
|
||||
remote_server_id = metadata.server_id
|
||||
is_same_server = remote_server_id == self.local_agent_metadata.server_id
|
||||
is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
prefill_metadata = self.local_agent_metadata
|
||||
decode_metadata = metadata
|
||||
else:
|
||||
prefill_metadata = metadata
|
||||
decode_metadata = self.local_agent_metadata
|
||||
comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
|
||||
cluster_rank_info = {
|
||||
prefill_metadata.cluster_id: 0,
|
||||
decode_metadata.cluster_id: 1
|
||||
}
|
||||
rank_table = {}
|
||||
rank_table["version"] = "1.2"
|
||||
rank_table["server_count"] = "1" if is_same_server else "2"
|
||||
rank_table["status"] = "completed"
|
||||
|
||||
# generate server_list for rank table
|
||||
rank_table["server_list"] = [] # type: ignore[assignment]
|
||||
decode_server_device_info = None
prefill_server_device_info = {
    "device": [{
        k: v
        for k, v in [("device_id", prefill_metadata.device_id),
                     ("device_ip", prefill_metadata.device_ip),
                     ("super_device_id", prefill_metadata.super_device_id),
                     ("rank_id", "0")] if v is not None
    }],
    "server_id": prefill_metadata.server_id,
}
if is_same_server:
    prefill_server_device_info["device"].append(  # type: ignore[attr-defined]
        {
            k: v
            for k, v in [("device_id", decode_metadata.device_id),
                         ("device_ip", decode_metadata.device_ip),
                         ("super_device_id", decode_metadata.super_device_id),
                         ("rank_id", "1")] if v is not None
        })
else:
    decode_server_device_info = {
        "device": [{
            k: v
            for k, v in [("device_id", decode_metadata.device_id),
                         ("device_ip", decode_metadata.device_ip),
                         ("super_device_id", decode_metadata.super_device_id),
                         ("rank_id", "1")] if v is not None
        }],
        "server_id": decode_metadata.server_id,
    }
rank_table["server_list"].append(  # type: ignore[attr-defined]
    prefill_server_device_info)
if decode_server_device_info is not None:
    rank_table["server_list"].append(  # type: ignore[attr-defined]
        decode_server_device_info)
|
||||
|
||||
if self.soc_info == AscendSocVersion.A3:
|
||||
# generate super_pod_list for rank table
|
||||
super_pod_list = []
|
||||
prefill_super_pod_info = {
|
||||
"super_pod_id": prefill_metadata.super_pod_id,
|
||||
"server_list": [{
|
||||
"server_id": prefill_metadata.server_id
|
||||
}],
|
||||
}
|
||||
if is_same_pod and not is_same_server:
|
||||
prefill_super_pod_info[
|
||||
"server_list"].append( # type: ignore[attr-defined]
|
||||
{"server_id": decode_metadata.server_id})
|
||||
super_pod_list.append(prefill_super_pod_info)
|
||||
if not is_same_pod:
|
||||
decode_super_pod_id = {
|
||||
"super_pod_id": decode_metadata.super_pod_id,
|
||||
"server_list": [{
|
||||
"server_id": decode_metadata.server_id
|
||||
}],
|
||||
}
|
||||
super_pod_list.append(decode_super_pod_id)
|
||||
rank_table[
|
||||
"super_pod_list"] = super_pod_list # type: ignore[assignment]
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
|
||||
)
|
||||
logger.info(f"rank table \n{rank_table}")
|
||||
logger.info(f"comm name: {comm_name}")
|
||||
logger.info(f"cluster rank info: {cluster_rank_info}")
|
||||
comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
|
||||
json.dumps(rank_table))
|
||||
while True:
|
||||
ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
|
||||
if ret == llm_datadist.RegisterMemStatus.OK:
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
|
||||
)
|
||||
break
|
||||
elif ret == llm_datadist.RegisterMemStatus.FAILED:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
|
||||
)
|
||||
time.sleep(1)
|
||||
logger.info("Checking query_register_mem_status again")
|
||||
self.linked_cluster.update({remote_cluster_id: comm_id})
|
||||
logger.info(f"cached linked cluster: {self.linked_cluster}")
|
||||
logger.info(
|
||||
f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
|
||||
)
|
||||
return remote_cluster_id
|
||||
|
||||
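# Illustrative sketch (not part of the original change): for a prefill and a decode
# device on two different servers, add_remote_agent() above produces a rank table
# roughly shaped like this (all values are placeholders):
#
# {
#     "version": "1.2",
#     "server_count": "2",
#     "status": "completed",
#     "server_list": [
#         {"server_id": "10.0.0.1",
#          "device": [{"device_id": "0", "device_ip": "192.168.100.1", "rank_id": "0"}]},
#         {"server_id": "10.0.0.2",
#          "device": [{"device_id": "0", "device_ip": "192.168.100.2", "rank_id": "1"}]}
#     ]
# }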
def remove_remote_agent(self, cluster_id: int):
|
||||
if cluster_id not in self.linked_cluster:
|
||||
    logger.warning(
        f"LLMDataDistCMgrConnectorWorker: Can't remove remote client with cluster id {cluster_id} because it is not in the linked_cluster list"
    )
    # Bail out early to avoid a KeyError on the lookup below.
    return
comm_id = self.linked_cluster[cluster_id]
|
||||
try:
|
||||
self.llm_datadist.unlink(comm_id)
|
||||
self.linked_cluster.pop(cluster_id)
|
||||
except LLMException:
|
||||
logger.error(
|
||||
f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment"
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully remove remote client with cluster id {cluster_id} !"
|
||||
)
|
||||
|
||||
def connect_to_remote_agent(self, host: str, port: int) -> int:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Querying metadata from url: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
logger.info("Try request remote metadata from socket......")
|
||||
sock.send(msg_send)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder()
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
|
||||
logger.info(f"recving metadata: {metadata}")
|
||||
cluster_id = self.add_remote_agent(metadata)
|
||||
return cluster_id
|
||||
|
||||
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
|
||||
for port in ports:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}"
|
||||
)
|
||||
|
||||
def _read_blocks(
|
||||
self,
|
||||
local_block_ids: list[int],
|
||||
remote_block_ids: list[int],
|
||||
remote_ip: str,
|
||||
remote_port: int,
|
||||
remote_engine_id: str,
|
||||
request_id: str,
|
||||
remote_tp_size: str,
|
||||
):
|
||||
# if remote_ip not in self.linked_cluster:
|
||||
tp_offset = self.tp_rank % int(remote_tp_size)
|
||||
remote_cluster_id = self.connect_to_remote_agent(
|
||||
remote_ip, remote_port + tp_offset)
|
||||
num_local_blocks = len(local_block_ids)
|
||||
if num_local_blocks == 0:
|
||||
return
|
||||
num_remote_blocks = len(remote_block_ids)
|
||||
assert num_local_blocks <= num_remote_blocks
|
||||
if num_local_blocks < num_remote_blocks:
|
||||
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
||||
|
||||
logger.info(f"remote cluster id is: {remote_cluster_id}")
|
||||
if self.use_mla:
|
||||
remote_cache_key_k_normed = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=0)
|
||||
remote_cache_key_k_pe = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=1)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_normed,
|
||||
self.cache[0], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_pe,
|
||||
self.cache[1], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
elif self.use_sparse:
|
||||
remote_cache_key_k_normed = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=0)
|
||||
remote_cache_key_k_pe = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=1)
|
||||
remote_cache_key_k_idx = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=2)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_normed,
|
||||
self.cache[0], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_pe,
|
||||
self.cache[1], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_idx,
|
||||
self.cache[2], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
else:
|
||||
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key,
|
||||
self.cache, # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
remote_ports = list(
|
||||
range(remote_port + self.tp_rank,
|
||||
remote_port + int(remote_tp_size), self.tp_size))
|
||||
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
|
||||
with self.thread_lock:
|
||||
self.finished_reqs.add(request_id)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requuests."""
|
||||
now = time.perf_counter()
|
||||
with self.thread_lock:
|
||||
while self.reqs_to_send:
|
||||
req_id, expires = next(iter(self.reqs_to_send.items()))
|
||||
if now < expires:
|
||||
break
|
||||
logger.warning(
|
||||
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
|
||||
"If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
|
||||
)
|
||||
if req_id in self.reqs_to_send:
|
||||
self.finished_reqs.add(req_id)
|
||||
del self.reqs_to_send[req_id]
|
||||
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
|
||||
self.finished_reqs.clear()
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
return req_ids_to_ret, None
|
||||
else:
|
||||
return None, req_ids_to_ret
|
||||
|
||||
|
||||
# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
|
||||
@contextlib.contextmanager
|
||||
def zmq_ctx(socket_type: Any,
|
||||
addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx: Optional[zmq.Context] = None # type: ignore[name-defined]
|
||||
try:
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
|
||||
if socket_type == zmq.ROUTER: # type: ignore[attr-defined]
|
||||
socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined]
|
||||
socket.bind(addr)
|
||||
elif socket_type == zmq.REQ: # type: ignore[attr-defined]
|
||||
socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined]
|
||||
socket.connect(addr)
|
||||
else:
|
||||
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||
|
||||
yield socket
|
||||
finally:
|
||||
if ctx is not None:
|
||||
ctx.destroy(linger=0)
|
||||
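# Minimal sketch (not part of the original change): how a client might use zmq_ctx()
# to request agent metadata from a listening worker, mirroring
# connect_to_remote_agent() above. The host/port values and the placeholder metadata
# dict below are illustrative assumptions, not real deployment values; the listener
# only replies when the payload carries a full metadata dict (it checks "cluster_id").
def _example_request_metadata(host: str = "127.0.0.1", port: int = 5557) -> Any:
    encoder = msgspec.msgpack.Encoder()
    decoder = msgspec.msgpack.Decoder()
    placeholder_metadata = {
        "super_pod_id": "0",
        "server_id": "10.0.0.1",
        "device_id": "0",
        "device_ip": "192.168.100.1",
        "super_device_id": "0",
        "cluster_id": 1,
    }
    payload = encoder.encode(
        [LLMDataDistCMgrEvent.ReqForMetadata, placeholder_metadata])
    with zmq_ctx(zmq.REQ, f"tcp://{host}:{port}") as sock:  # type: ignore[attr-defined]
        sock.send(payload)
        # The listener replies with its own msgpack-encoded agent metadata.
        return decoder.decode(sock.recv())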
0
vllm_npu/distributed/mooncake/__init__.py
Normal file
449
vllm_npu/distributed/mooncake/config_data.py
Normal file
@@ -0,0 +1,449 @@
|
||||
import array
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.utils import cdiv, logger
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeEngineMetadata:
|
||||
"""name of the LLM model"""
|
||||
|
||||
model_name: str
|
||||
""" world size when running under a distributed setting """
|
||||
world_size: int
|
||||
""" worker id when running under a distributed setting """
|
||||
worker_id: int
|
||||
""" the format of kv tensors """
|
||||
kv_dtype: torch.dtype
|
||||
""" the shape of kv tensors """
|
||||
""" (num_layer, 2, metadata.block_size, num_kv_head, head_size) """
|
||||
kv_shape: tuple[int, int, int, int, int]
|
||||
block_size: int = 128
|
||||
""" whether use MLA"""
|
||||
use_mla: bool = False
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class MooncakeEngineKey:
|
||||
model_name: str
|
||||
world_size: int
|
||||
worker_id: int
|
||||
chunk_hash: str
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (f"{self.model_name}@{self.world_size}"
|
||||
f"@{self.worker_id}@{self.chunk_hash}")
|
||||
|
||||
def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
|
||||
"""Split the key into multiple keys for each layer"""
|
||||
keys = []
|
||||
for layer_id in range(num_layers):
|
||||
keys.append(
|
||||
LayerMooncakeEngineKey(
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
layer_id,
|
||||
))
|
||||
return keys
|
||||
|
||||
def to_dict(self):
|
||||
# Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
|
||||
return {
|
||||
"__type__": "CacheEngineKey",
|
||||
"model_name": self.model_name,
|
||||
"world_size": self.world_size,
|
||||
"worker_id": self.worker_id,
|
||||
"chunk_hash": self.chunk_hash,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return MooncakeEngineKey(
|
||||
model_name=d["model_name"],
|
||||
world_size=d["world_size"],
|
||||
worker_id=d["worker_id"],
|
||||
chunk_hash=d["chunk_hash"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class LayerMooncakeEngineKey(MooncakeEngineKey):
|
||||
"""A key for the layer cache engine"""
|
||||
|
||||
layer_id: int
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
self.layer_id,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (f"{self.model_name}@{self.world_size}"
|
||||
f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
|
||||
|
||||
|
||||
class ChunkedTokenDatabase():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata: MooncakeEngineMetadata,
|
||||
):
|
||||
self.metadata = metadata
|
||||
|
||||
def _make_key_by_hash(self,
|
||||
chunk_hash: str,
|
||||
layer_id: Optional[int] = None):
|
||||
assert self.metadata is not None
|
||||
return MooncakeEngineKey(
|
||||
self.metadata.model_name,
|
||||
self.metadata.world_size,
|
||||
self.metadata.worker_id,
|
||||
chunk_hash,
|
||||
)
|
||||
|
||||
def _hash(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
prefix_hash: str,
|
||||
) -> str:
|
||||
# TODO: change it to a more efficient hash function
|
||||
if isinstance(tokens, torch.Tensor):
    tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
elif isinstance(tokens, list):
    tokens_bytes = array.array("I", tokens).tobytes()
else:
    raise TypeError(f"Unsupported token container type: {type(tokens)}")
return hashlib.sha256(prefix_hash.encode("ascii") +
                      tokens_bytes).hexdigest()
|
||||
|
||||
def _chunk_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
) -> Iterable[Union[torch.Tensor, List[int]]]:
|
||||
"""
|
||||
Chunk the tokens into chunks of size self.metadata.block_size.
|
||||
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
device: the target device after chunking
|
||||
|
||||
:return: a generator of chunks of tokens, each with
|
||||
shape [metadata.block_size]
|
||||
"""
|
||||
for i in range(0, len(tokens), self.metadata.block_size):
|
||||
yield tokens[i:i + self.metadata.block_size]
|
||||
|
||||
def _prefix_hash(
|
||||
self,
|
||||
token_chunks: Iterable[Union[torch.Tensor, List[int]]],
|
||||
) -> Iterable[str]:
|
||||
prefix_hash = ''
|
||||
for token_chunk in token_chunks:
|
||||
prefix_hash = self._hash(token_chunk, prefix_hash)
|
||||
yield prefix_hash
|
||||
|
||||
def process_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
|
||||
"""Process the tokens and return the corresponding cache engine keys.
|
||||
|
||||
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
    have the same length as tokens. The mask should ALWAYS look like
    FFFFFTTTTTTT, where True means the token needs to be matched,
    and the Falses ALWAYS sit at the PREFIX of the tensor.

:returns: An iterable of tuples with three elements. The first element
    is the start index of the tokens for the key. The second element
    is the end index of the tokens for the key. The third element is
    the cache engine key for the tokens.
|
||||
|
||||
:raises: ValueError if the number of Falses in the mask is not a
|
||||
multiple of the chunk size.
|
||||
"""
|
||||
if mask is not None:
|
||||
num_falses = mask.numel() - mask.long().sum().item()
|
||||
else:
|
||||
num_falses = 0
|
||||
|
||||
if num_falses % self.metadata.block_size != 0:
|
||||
raise ValueError(
|
||||
"The number of Falses in the mask is not a multiple of the chunk size."
|
||||
)
|
||||
total_len = len(tokens)
|
||||
|
||||
token_chunks = self._chunk_tokens(tokens)
|
||||
prefix_hashes = self._prefix_hash(token_chunks)
|
||||
|
||||
start_idx = 0
|
||||
for chunk_id, hash_val in enumerate(prefix_hashes):
|
||||
start_idx = chunk_id * self.metadata.block_size
|
||||
end_idx = min(start_idx + self.metadata.block_size, total_len)
|
||||
if start_idx < num_falses:
|
||||
continue
|
||||
else:
|
||||
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
|
||||
|
||||
|
||||
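# Minimal sketch (not part of the original change): how ChunkedTokenDatabase splits a
# token sequence into block_size chunks and derives rolling prefix-hash keys. The
# metadata values are placeholders chosen only to keep the example small.
def _example_chunked_keys() -> List[str]:
    meta = MooncakeEngineMetadata(
        model_name="example-model",
        world_size=1,
        worker_id=0,
        kv_dtype=torch.float16,
        kv_shape=(1, 2, 4, 1, 8),
        block_size=4,
    )
    db = ChunkedTokenDatabase(meta)
    # Eight tokens with block_size=4 yield two chunks, [0, 4) and [4, 8); each key's
    # hash covers the whole prefix up to the end of its chunk.
    return [key.to_string() for _, _, key in db.process_tokens(list(range(8)))]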
@dataclass
|
||||
class LoadSpec:
|
||||
# Number of tokens cached in vLLM
|
||||
vllm_cached_tokens: int
|
||||
# Number of tokens that are cached in mooncake
|
||||
mooncake_cached_tokens: int
|
||||
# Whether the scheduler allows us to load the tokens
|
||||
can_load: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaveSpec:
|
||||
# Skip already saved tokens
|
||||
skip_leading_tokens: int
|
||||
# Whether the scheduler allows us to save the tokens
|
||||
can_save: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestTracker:
|
||||
# Request id
|
||||
req_id: str
|
||||
|
||||
# The token ids that have been scheduled so far
|
||||
token_ids: list[int]
|
||||
|
||||
# The block ids that have been allocated so far
|
||||
# NOTE: allocated blocks could be more than the number of tokens
|
||||
# FIXME: need to check whether the block ids will be changed after
|
||||
# preemption
|
||||
allocated_block_ids: list[int]
|
||||
|
||||
# The number of tokens that have been saved
|
||||
num_saved_tokens: int = 0
|
||||
|
||||
@staticmethod
|
||||
def from_new_request(
|
||||
new_request: "NewRequestData",
|
||||
num_tokens_to_compute: int,
|
||||
) -> "RequestTracker":
|
||||
"""Create the request tracker from a new request.
|
||||
|
||||
Args:
|
||||
new_request (NewRequestData): the new request data.
|
||||
num_tokens_to_compute (int): the number of tokens that will
|
||||
be 'computed', including the `num_computed_tokens` (vLLM's
|
||||
local cache hit) and new tokens that will be scheduled.
|
||||
|
||||
"""
|
||||
# vLLM 0.9.0 update: request.block_ids changed from list[int] to
|
||||
# list[list[int]]
|
||||
# Need to check the type of request.block_ids
|
||||
|
||||
unfolded_block_ids = []
|
||||
|
||||
if not isinstance(new_request.block_ids[0], list):
|
||||
unfolded_block_ids = new_request.block_ids.copy()
|
||||
else:
|
||||
unfolded_block_ids = new_request.block_ids[0].copy()
|
||||
|
||||
return RequestTracker(
|
||||
req_id=new_request.req_id,
|
||||
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
|
||||
copy(),
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
new_block_ids: Union[tuple[list[int], ...], list[int]],
|
||||
) -> None:
|
||||
"""Update the request tracker when a running request is
|
||||
scheduled again
|
||||
"""
|
||||
|
||||
self.token_ids.extend(new_token_ids)
|
||||
|
||||
if len(new_block_ids) == 0:
|
||||
new_block_ids = []
|
||||
elif isinstance(new_block_ids, tuple):
|
||||
new_block_ids = new_block_ids[0]
|
||||
elif isinstance(new_block_ids, list):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported new_block_ids type {type(new_block_ids)}")
|
||||
self.allocated_block_ids.extend(new_block_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request id
|
||||
req_id: str
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
|
||||
block_ids: list[int]
|
||||
# # Slot mapping if exchange for block_id
|
||||
# slot_mapping: torch.Tensor
|
||||
# Skip save or not
|
||||
save_spec: Optional[SaveSpec] = None
|
||||
# load_spec
|
||||
load_spec: Optional[LoadSpec] = None
|
||||
|
||||
is_last_chunk: Optional[bool] = None
|
||||
|
||||
@staticmethod
|
||||
def from_request_tracker(
|
||||
tracker: RequestTracker,
|
||||
block_size: int,
|
||||
load_spec: Optional[LoadSpec] = None,
|
||||
skip_save: Optional[bool] = False,
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
discard_partial_chunks: bool = True,
|
||||
) -> Optional["ReqMeta"]:
|
||||
"""Create the request metadata from a request tracker.
|
||||
|
||||
Args:
|
||||
tracker (RequestTracker): the request tracker.
|
||||
block_size (int): the block size in vLLM.
|
||||
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
|
||||
skip_save (bool): whether to skip the save operation.
|
||||
discard_partial_chunks (bool): whether to discard partial chunks.
|
||||
|
||||
Returns:
|
||||
the request metadata if we need to perform load/save
|
||||
operations, None otherwise.
|
||||
"""
|
||||
input_token_ids = tracker.token_ids
|
||||
input_token_len = len(input_token_ids)
|
||||
|
||||
# For save operation: do not save if the following condition is met
|
||||
# 1. has already been saved before (num_saved_tokens > 0)
|
||||
# 2. number of unsaved tokens is not reached the chunk boundary
|
||||
skip_leading_tokens = tracker.num_saved_tokens
|
||||
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
|
||||
block_size if discard_partial_chunks else 0)
|
||||
# Calculate number of tokens to save based on discard_partial_chunks
|
||||
# setting
|
||||
num_tokens_to_save = ((input_token_len // block_size * block_size)
|
||||
if discard_partial_chunks else input_token_len)
|
||||
|
||||
skip_save = skip_save or num_tokens_to_save < chunk_boundary
|
||||
if skip_save and load_spec is None:
|
||||
return None
|
||||
|
||||
# If we need to save, update the number of saved tokens
|
||||
if not skip_save:
|
||||
tracker.num_saved_tokens = num_tokens_to_save
|
||||
save_spec = SaveSpec(skip_leading_tokens, not skip_save)
|
||||
|
||||
# Calculate the token ids and slot mappings for load and save
|
||||
# OPTIMIZATION: pre-allocate the buffer for token ids and block ids
|
||||
token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]
|
||||
|
||||
# # For load operation: check whether the request is scheduled to load
|
||||
if load_spec is not None and load_spec.can_load:
|
||||
logger.debug(
|
||||
"Scheduled to load %d tokens for request %s",
|
||||
load_spec.mooncake_cached_tokens,
|
||||
tracker.req_id,
|
||||
)
|
||||
else:
|
||||
# Do not load if not in `can_load` state
|
||||
load_spec = None
|
||||
logger.debug(
|
||||
f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}"
|
||||
)
|
||||
return ReqMeta(
|
||||
req_id=tracker.req_id,
|
||||
token_ids=token_ids,
|
||||
block_ids=tracker.allocated_block_ids,
|
||||
save_spec=save_spec,
|
||||
load_spec=load_spec,
|
||||
is_last_chunk=is_last_chunk,
|
||||
)
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self, unfinished_request_ids):
|
||||
self.requests = []
|
||||
self.unfinished_request_ids = unfinished_request_ids
|
||||
|
||||
def add_request(self, req_meta: ReqMeta) -> None:
|
||||
"""Add a request to the metadata.
|
||||
|
||||
Args:
|
||||
req_meta (ReqMeta): the request metadata.
|
||||
"""
|
||||
self.requests.append(req_meta)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LasyerMultiBlockReqMeta:
|
||||
req_id: str
|
||||
keys: List[LayerMooncakeEngineKey]
|
||||
starts: List[int]
|
||||
ends: list[int]
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeStoreConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
global_segment_size: int
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
master_server_address: str
|
||||
use_ascend_direct: bool
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
||||
with open(file_path) as file:
|
||||
config = json.load(file)
|
||||
return MooncakeStoreConfig(
|
||||
local_hostname=config.get("local_hostname"),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=config.get("global_segment_size", 3355443200),
|
||||
local_buffer_size=config.get("local_buffer_size", 1073741824),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"),
|
||||
use_ascend_direct=config.get("use_ascend_direct", False))
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> "MooncakeStoreConfig":
|
||||
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_path)
|
||||
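# Minimal sketch (not part of the original change): the shape of a config file that
# MooncakeStoreConfig.load_from_env() reads via MOONCAKE_CONFIG_PATH. All values are
# placeholders; only the keys mirror what from_file() looks up.
_EXAMPLE_MOONCAKE_CONFIG = {
    "local_hostname": "127.0.0.1",
    "metadata_server": "127.0.0.1:2379",
    "global_segment_size": 3355443200,
    "local_buffer_size": 1073741824,
    "protocol": "tcp",
    "device_name": "",
    "master_server_address": "127.0.0.1:50051",
    "use_ascend_direct": False,
}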
293
vllm_npu/distributed/mooncake/kv_transfer.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.utils import logger
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
|
||||
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class KVTransferThread(threading.Thread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event, name: str):
|
||||
super().__init__(daemon=True, name=name)
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.m_store = m_store
|
||||
self.ready_event = ready_event
|
||||
self.kv_caches_base_addr = local_kv_caches_base_addr
|
||||
self.block_len = block_len
|
||||
self.token_database = token_database
|
||||
self.block_size = block_size
|
||||
self.done_task_lock = threading.Lock()
|
||||
# TODO(jianzs): find a better way to detect MLA.
|
||||
self.use_mla = len(block_len) == 2
|
||||
|
||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||
# TODO(jianzs): make this configurable
|
||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||
self.finished_requests: set[str] = set()
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
|
||||
|
||||
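# Worked example (illustrative): with block_size=128, a non-MLA model whose per-layer
# block_len is 65536 bytes, block_ids=[7] and a token range [0, 128), the code above
# resolves block_id = block_ids[0 // 128] = 7, so for every registered layer address
# it returns addr = base_addr + 7 * 65536 and length = 65536 / 128 * 128 = 65536
# bytes, i.e. the full block for that layer.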
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
|
||||
layer_id: int):
|
||||
block_id = block_ids[start // self.block_size]
|
||||
if self.use_mla:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[1]
|
||||
length_k = int(self.block_len[0] / self.block_size * (end - start))
|
||||
length_v = int(self.block_len[1] / self.block_size * (end - start))
|
||||
size_list = [length_k, length_v]
|
||||
else:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[0]
|
||||
length = int(self.block_len[0] / self.block_size * (end - start))
|
||||
size_list = [length, length]
|
||||
addr_list = [addr_k, addr_v]
|
||||
return addr_list, size_list
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
current_event: Optional[torch.npu.Event] = None,
|
||||
) -> torch.Tensor:
|
||||
req = ({
|
||||
"req_id": req_id,
|
||||
"tokens": tokens,
|
||||
"block_ids": block_ids,
|
||||
"mask": mask,
|
||||
"is_last_chunk": is_last_chunk,
|
||||
"current_event": current_event,
|
||||
})
|
||||
self.request_queue.put(req)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
Get and clear the requests that have been completed.
|
||||
Returns:
|
||||
A set of request IDs that have been completed.
|
||||
"""
|
||||
with self.done_task_lock:
|
||||
finished_requests = self.finished_requests.copy()
|
||||
self.finished_requests.clear()
|
||||
return finished_requests
|
||||
|
||||
def set_finished_request(self, req_id):
|
||||
with self.done_task_lock:
|
||||
self.finished_requests.add(req_id)
|
||||
|
||||
def run(self):
|
||||
"""Run the thread to handle KV cache transfer requests."""
|
||||
self.ready_event.set()
|
||||
while True:
|
||||
try:
|
||||
request_data = self.request_queue.get()
|
||||
if request_data is None:
|
||||
logger.warning("Received a None request!")
|
||||
self.request_queue.task_done()
|
||||
continue
|
||||
self._handle_request(request_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in KVCacheTransferThread: {e}")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
pass
|
||||
|
||||
|
||||
class KVCacheStoreSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheSendingThread")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
tokens = req_meta["tokens"]
|
||||
mask = req_meta["mask"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
req_id = req_meta["req_id"]
|
||||
is_last_chunk = req_meta["is_last_chunk"]
|
||||
current_event = req_meta["current_event"]
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
if key_list:
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
This issue will be fixed in CANN version 8.5.rc1.
|
||||
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||
to resolve this issue before the 8.5.RC1 release.
|
||||
"""
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put(key, addr, size)
|
||||
if is_last_chunk:
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreRecvingThread")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
tokens = req_meta["tokens"]
|
||||
mask = req_meta["mask"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
req_id = req_meta["req_id"]
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||
self.m_store.get(key, addr, size)
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event,
|
||||
num_layers: int):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerSendingThread")
|
||||
self.final_layer_id = num_layers - 1
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.prepare_value_layer(req_meta.starts[index],
|
||||
req_meta.ends[index],
|
||||
req_meta.block_ids,
|
||||
req_meta.layer_id)
|
||||
self.m_store.put(key, addr, size)
|
||||
if req_meta.layer_id == self.final_layer_id:
|
||||
self.set_finished_request(req_meta.req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event,
|
||||
get_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerRecvingThread")
|
||||
self.get_event = get_event
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.prepare_value_layer(req_meta.starts[index],
|
||||
req_meta.ends[index],
|
||||
req_meta.block_ids,
|
||||
req_meta.layer_id)
|
||||
self.m_store.get(key, addr, size)
|
||||
self.request_queue.task_done()
|
||||
self.get_event.set()
|
||||
639
vllm_npu/distributed/mooncake/mooncake_engine.py
Normal file
@@ -0,0 +1,639 @@
|
||||
# Standard
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import get_kv_cache_torch_dtype, logger
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
|
||||
MooncakeEngineMetadata)
|
||||
from vllm_npu.distributed.mooncake.kv_transfer import (
|
||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class MooncakeEngine:
|
||||
# The main class for the cache engine.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
use_layerwize: bool,
|
||||
):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.use_mla = False
|
||||
if (hasattr(model_config, "use_mla")
|
||||
and isinstance(model_config.use_mla, bool)
|
||||
and model_config.use_mla):
|
||||
self.use_mla = True
|
||||
self.use_layerwise = use_layerwize
|
||||
self.tp_rank = parallel_config.rank
|
||||
self.tp_size = parallel_config.tensor_parallel_size
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"register_buffer", False)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.current_layer = 0
|
||||
# self.use_mla = first_kv_cache_tuple[0].size(
|
||||
# -1) != first_kv_cache_tuple[1].size(-1)
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
num_kv_head = model_config.get_num_kv_heads(parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
kv_dtype = get_kv_cache_torch_dtype(
|
||||
vllm_config.cache_config.cache_dtype, model_config.dtype)
|
||||
self.hidden_dim_size = num_kv_head * head_size
|
||||
if self.use_mla:
|
||||
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
|
||||
else:
|
||||
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
|
||||
head_size)
|
||||
self.metadata = MooncakeEngineMetadata(
|
||||
model_config.model,
|
||||
parallel_config.world_size,
|
||||
parallel_config.rank,
|
||||
kv_dtype,
|
||||
kv_shape,
|
||||
self.block_size,
|
||||
self.use_mla,
|
||||
)
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata)
|
||||
|
||||
self.m_store = Mooncakestore(parallel_config)
|
||||
|
||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
self.kv_caches_base_addr = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
self._register(base_addr, region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
self._register(base_addr, region_len)
|
||||
|
||||
if self.use_layerwise:
|
||||
self.get_event = threading.Event()
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending,
|
||||
self.num_layers)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database, self.block_len,
|
||||
self.block_size, ready_event, self.get_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def _register(self, ptr, length):
|
||||
logger.debug(
|
||||
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
|
||||
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
|
||||
try:
|
||||
self.m_store.register_buffer(ptr, length)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Mooncake memory registration failed. Error is: {e}")
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
self.current_layer = 0
|
||||
self.layerwise_retrievers = []
|
||||
for request in metadata.requests:
|
||||
load_spec = request.load_spec
|
||||
if load_spec is None or not load_spec.can_load:  # nothing to load for this request
|
||||
continue
|
||||
tokens = request.token_ids
|
||||
req_id = request.req_id
|
||||
if (load_spec.mooncake_cached_tokens % self.block_size
|
||||
!= 0) and (load_spec.mooncake_cached_tokens
|
||||
== tokens.shape[0] - 1):
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
|
||||
else:
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
|
||||
masked_token_count = (request.load_spec.vllm_cached_tokens //
|
||||
self.block_size * self.block_size)
|
||||
token_mask = torch.ones_like(tokens, dtype=torch.bool)
|
||||
token_mask[:masked_token_count] = False
|
||||
if self.use_layerwise:
|
||||
layerwise_retriever = self.retrieve_layer(
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
next(layerwise_retriever) # first layer load
|
||||
self.layerwise_retrievers.append(layerwise_retriever)
|
||||
else:
|
||||
if self.load_async:
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
else:
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
self.m_store.get_batch(key_list, addr_list, size_list,
|
||||
blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, _ = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
self.m_store.get(key, addr, size)
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
|
||||
|
||||
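A minimal sketch (illustrative sizes, not taken from this diff) of the address arithmetic that prepare_value() performs above: each KV tensor spans num_blocks * block_len bytes, so block block_id starts at base_addr + block_id * block_len, and a span of (end - start) tokens covers a proportional slice of that block.
block_size = 128                      # tokens per block (assumed)
elem_size, kv_heads, head_dim = 2, 8, 128
block_len = block_size * kv_heads * head_dim * elem_size  # bytes per block (assumed layout)
base_addr = 0x10000000                # hypothetical data_ptr() of one KV tensor
block_id, start, end = 3, 384, 512    # token range served by this block
addr = base_addr + block_id * block_len
length = int(block_len / block_size * (end - start))
print(hex(addr), length)              # start address and byte count to transfer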
def wait_for_layer_load(self) -> None:
|
||||
"""MooncakeConnector does not do layerwise saving."""
|
||||
for layerwise_retriever in self.layerwise_retrievers:
|
||||
ret_token_mask = next(layerwise_retriever)
|
||||
if self.current_layer == self.num_layers - 1:
|
||||
assert ret_token_mask is not None
|
||||
num_retrieved_tokens = ret_token_mask.sum().item()
|
||||
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
|
||||
|
||||
def save_kv_layer(self,
|
||||
connector_metadata: MooncakeConnectorMetadata) -> None:
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
if self.current_layer == 0:
|
||||
self.layerwise_storers = []
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
# TODO: decide whether the save thread can be removed
|
||||
# Determine how many leading tokens are already stored and can be skipped
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
layerwise_storer = self.store_layer(
|
||||
req_id,
|
||||
token_ids,
|
||||
mask=store_mask,
|
||||
block_ids=request.block_ids,
|
||||
)
|
||||
self.layerwise_storers.append(layerwise_storer)
|
||||
for layerwise_storer in self.layerwise_storers:
|
||||
try:
|
||||
next(layerwise_storer)
|
||||
except Exception:
|
||||
raise
|
||||
self.current_layer = self.current_layer + 1
|
||||
|
||||
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
token_ids,
|
||||
request.block_ids,
|
||||
store_mask,
|
||||
request.is_last_chunk,
|
||||
current_event,
|
||||
)
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[Optional[torch.Tensor], None, None]:
|
||||
"""
|
||||
Retrieve the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens need to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the KV transfer which
|
||||
will be passed into the npu_transfer.
|
||||
|
||||
return: A generator that yields Optional[torch.Tensor]. The tensor will
|
||||
be the boolean mask indicating which tokens are retrieved and will
|
||||
only be returned in the last iteration.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_required_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_required_tokens = len(tokens)
|
||||
|
||||
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
first_flag = True
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer)
|
||||
ret_mask[start:end] = True
|
||||
|
||||
if keys:
|
||||
# Transpose the keys into layer major format
|
||||
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
if not first_flag:
|
||||
is_finish = self.get_event.wait(timeout=3)  # wait for the previous layer's fetch
|
||||
if not is_finish:
|
||||
logger.info("Layerwise get failed")
|
||||
self.get_event.clear()
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
first_flag = False
|
||||
yield None
|
||||
else:
|
||||
# If no caches are found, we still need to yield to avoid
|
||||
# `StopIteration`
|
||||
for layer_id in range(self.num_layers):
|
||||
yield None
|
||||
|
||||
retrieved_tokens = torch.sum(ret_mask)
|
||||
logger.debug(f"Retrieved {retrieved_tokens} "
|
||||
f"out of {num_required_tokens} "
|
||||
f"out of total {len(tokens)} tokens")
|
||||
|
||||
yield ret_mask
|
||||
|
||||
def store_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Store the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens need to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the storage backend which
|
||||
will be passed into the gpu_connector.
|
||||
|
||||
return: A generator that yields None. In the first iteration, the
|
||||
generator allocates the memory objects for all layers and moves
|
||||
the KV cache of the first layer from GPU to CPU. In the next
|
||||
iterations, it moves the KV cache of layer i from GPU to the memory
|
||||
objects (on CPU) and puts the memory objects of layer i-1 to the
|
||||
storage backends. In the last iteration, it puts the memory objects
|
||||
of the last layer to the storage backends.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_stored_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_stored_tokens = len(tokens)
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer) #[block_num,layer_num]
|
||||
|
||||
if keys:
|
||||
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
yield
|
||||
else:
|
||||
for layer_id in range(self.num_layers):
|
||||
yield
|
||||
logger.debug(
|
||||
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.kv_send_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
|
||||
|
||||
done_recving = (
|
||||
self.kv_recv_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.load_async else set())
|
||||
|
||||
logger.debug(
|
||||
"Number of completed KV cache send requests: %d, receive "
|
||||
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
|
||||
self.tp_rank)
|
||||
return done_sending, done_recving
|
||||
|
||||
def wait_layer_transfer_finish(self):
|
||||
time.sleep(10)
|
||||
pass
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
res = self.m_store.batch_exists(
|
||||
keys) # type: ignore[assignment]
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return starts[index]
|
||||
# all tokens were found; return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def lookup_scheduler(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
multi_tp_keys = keys[:]
|
||||
for i in range(1, self.tp_size):
|
||||
for item in keys:
|
||||
new_str = item.replace( # type: ignore[attr-defined]
|
||||
"@0", f"@{i}", 1)
|
||||
multi_tp_keys.append(new_str)
|
||||
res = self.m_store.batch_exists(
|
||||
multi_tp_keys) # type: ignore[assignment]
|
||||
num_block = len(keys)
|
||||
multi_tp_values = [
|
||||
res[i * num_block:(i + 1) *
|
||||
num_block] # type: ignore[index]
|
||||
for i in range(self.tp_size)
|
||||
]
|
||||
index = self.find_min_first_non_one_index(multi_tp_values)
|
||||
if index != -1:
|
||||
return starts[index]
|
||||
# all tokens were found; return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def find_min_first_non_one_index(self, arr):
|
||||
try:
|
||||
return min(idx for row in arr for idx, val in enumerate(row)
|
||||
if val != 1)
|
||||
except ValueError:
|
||||
return -1
|
||||
|
||||
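An illustrative check (synthetic values) of the truncation logic above: batch_exists returns one flag per key per tensor-parallel rank, and the earliest index at which any rank misses bounds how many cached blocks lookup_scheduler can report. The min(...) expression below is equivalent to find_min_first_non_one_index.
multi_tp_values = [
    [1, 1, 0, 1],   # rank 0: block 2 missing
    [1, 1, 1, 1],   # rank 1: all blocks present
]
first_miss = min(
    (idx for row in multi_tp_values for idx, val in enumerate(row) if val != 1),
    default=-1,
)
print(first_miss)  # -> 2, so only the first two blocks count as cached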
def close(self) -> None:
|
||||
"""Close the cache engine and free all the resources"""
|
||||
self.m_store.close()
|
||||
126
vllm_npu/distributed/mooncake/mooncake_store.py
Normal file
126
vllm_npu/distributed/mooncake/mooncake_store.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# Standard
|
||||
import os
|
||||
|
||||
# Third Party
|
||||
from mooncake.store import ReplicateConfig # type: ignore
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from vllm.utils import get_ip, logger
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import MooncakeEngineKey
|
||||
from vllm_npu.distributed.mooncake.transfer_engine import get_global_te
|
||||
|
||||
from .config_data import MooncakeStoreConfig
|
||||
|
||||
METADATA_BYTES_LEN = 24
|
||||
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
|
||||
|
||||
|
||||
class Mooncakestore():
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
try:
|
||||
from mooncake.store import MooncakeDistributedStore # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank_local
|
||||
all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
|
||||
if not all_device_ids:
|
||||
device_ids_list = list(
|
||||
range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
|
||||
else:
|
||||
device_ids_list = list(map(int, all_device_ids.split(',')))
|
||||
assert len(device_ids_list) > tp_rank
|
||||
device_id = device_ids_list[tp_rank]
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
self.store = MooncakeDistributedStore()
|
||||
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
|
||||
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
|
||||
":npu_" + str(device_id)
|
||||
ret = self.store.setup(local_hostname, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address)
|
||||
else:
|
||||
local_hostname = get_ip()
|
||||
transfer_engine = get_global_te(local_hostname, device_name=None)
|
||||
self.local_seg = local_hostname + ":" + str(
|
||||
transfer_engine.get_rpc_port())
|
||||
ret = self.store.setup(self.local_seg, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address,
|
||||
transfer_engine.get_engine())
|
||||
if ret != 0:
|
||||
msg = "Initialize mooncake failed."
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def exists(self, key: MooncakeEngineKey) -> bool:
|
||||
return self.store.is_exist(key.to_string()) == 1
|
||||
|
||||
def batch_exists(self, keys: list[str]) -> list[int]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
def register_buffer(self, ptr, length):
|
||||
return self.store.register_buffer(ptr, length)
|
||||
|
||||
def get_batch(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]], block_ids: list[int]):
|
||||
try:
|
||||
res = self.store.batch_get_into_multi_buffers(
|
||||
keys, addrs, sizes, True)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to get key {keys},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get key {keys}. {e}")
|
||||
|
||||
def put_batch(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]], block_ids: list[int]):
|
||||
try:
|
||||
config = ReplicateConfig()
|
||||
config.preferred_segment = self.local_seg
|
||||
config.prefer_alloc_in_same_node = True
|
||||
res = self.store.batch_put_from_multi_buffers(
|
||||
keys, addrs, sizes, config)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to put key {keys},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to put key {keys},error:{e}")
|
||||
|
||||
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||
expect_res = sum(size)
|
||||
key_str = key.to_string()
|
||||
try:
|
||||
res = self.store.batch_get_into_ascend(key_str, addr, size)
|
||||
if res[0] != expect_res:
|
||||
logger.error(f"Failed to get key: [{key_str}] .")
|
||||
except Exception:
|
||||
logger.error(f"Failed to get key: [{key_str}] .")
|
||||
return res
|
||||
|
||||
def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||
key_str = key.to_string()
|
||||
try:
|
||||
ret = self.store.batch_put_from_ascend(key_str, addr, size)
|
||||
if ret[0] != 0:
|
||||
logger.error(f"Failed to put key {key_str}.")
|
||||
except Exception:
|
||||
logger.error(f"Failed to put key {key_str}.")
|
||||
|
||||
return ret
|
||||
|
||||
def close(self):
|
||||
self.store.close()
|
||||
logger.info("Closed the mooncake store connection")
|
||||
494
vllm_npu/distributed/mooncake/mooncake_store_connector_v1.py
Normal file
494
vllm_npu/distributed/mooncake/mooncake_store_connector_v1.py
Normal file
@@ -0,0 +1,494 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import logger, make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import (
|
||||
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
|
||||
from vllm_npu.distributed.mooncake.mooncake_engine import MooncakeEngine
|
||||
|
||||
|
||||
class MooncakeConnectorV1(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
|
||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"use_layerwise", False)
|
||||
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.sended_but_unfinished_reqs: set[str] = set()
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
|
||||
vllm_config, self.use_layerwise)
|
||||
else:
|
||||
self.connector_worker = MooncakeEngine(
|
||||
vllm_config,
|
||||
self.use_layerwise,
|
||||
)
|
||||
|
||||
assert self.connector_worker is not None
|
||||
if vllm_config.parallel_config.rank == 0 and self.kv_role != "kv_consumer":
|
||||
self.lookup_server = MooncakeLookupServer(
|
||||
self.connector_worker, vllm_config, self.use_layerwise)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._get_connector_metadata(),
|
||||
MooncakeConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""MooncakeStoreConnector does not do layerwise saving."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
self.connector_worker.save_kv_layer(self._get_connector_metadata())
|
||||
|
||||
def wait_for_save(self):
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
|
||||
if self.use_layerwise:
|
||||
self.connector_worker.wait_layer_transfer_finish()
|
||||
return
|
||||
|
||||
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
meta = self._get_connector_metadata()
|
||||
done_sending, done_recving = self.connector_worker.get_finished()
|
||||
sended_and_finished: set[str] = set()
|
||||
for item in list(self.sended_but_unfinished_reqs):
|
||||
if item not in meta.unfinished_request_ids:
|
||||
sended_and_finished.add(item)
|
||||
self.sended_but_unfinished_reqs.remove(item)
|
||||
for item in done_sending:
|
||||
if item in meta.unfinished_request_ids:
|
||||
self.sended_but_unfinished_reqs.add(item)
|
||||
else:
|
||||
sended_and_finished.add(item)
|
||||
|
||||
return sended_and_finished, done_recving
|
||||
|
||||
|
||||
def get_zmq_rpc_path_mooncake(
|
||||
vllm_config: Optional["VllmConfig"] = None, ) -> str:
|
||||
base_url = envs.VLLM_RPC_BASE_PATH
|
||||
# Default to 0 if not configured
|
||||
rpc_port = 0
|
||||
if vllm_config is not None:
|
||||
rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"mooncake_rpc_port", 0)
|
||||
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
|
||||
return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
|
||||
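A hedged example of the RPC path construction above, assuming envs.VLLM_RPC_BASE_PATH resolves to /tmp/vllm_rpc and the kv_transfer_config carries {"mooncake_rpc_port": 7}; the scheduler-side lookup client and the worker-side lookup server must derive the same string to rendezvous over ZMQ IPC.
base_url = "/tmp/vllm_rpc"            # assumed value of envs.VLLM_RPC_BASE_PATH
rpc_port = 7                          # assumed "mooncake_rpc_port" in the extra config
print(f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}")
# -> ipc:///tmp/vllm_rpc/mooncake_rpc_port_7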
|
||||
|
||||
class MooncakeStoreConnectorV1Scheduler:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
||||
self.use_layerwise = use_layerwise
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.client = MooncakeLookupClient(
|
||||
vllm_config) if self.kv_role != "kv_consumer" else None
|
||||
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_load", False)
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
# request_id -> (vllm cached tokens, mooncake cached tokens)
|
||||
self.load_specs: dict[str, LoadSpec] = {}
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
# request_id -> full_token_ids
|
||||
self._request_trackers: dict[str, RequestTracker] = {}
|
||||
# Whether to discard partial chunks
|
||||
self._discard_partial_chunks = (
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"discard_partial_chunks", True))
|
||||
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._unfinished_request_ids: set[str] = set()
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Check for external KV cache hit.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
|
||||
return 0, False
|
||||
|
||||
if self._discard_partial_chunks:
|
||||
token_block_end = len(request.prompt_token_ids
|
||||
) // self._block_size * self._block_size
|
||||
token_ids = torch.tensor(
|
||||
request.prompt_token_ids[:token_block_end])
|
||||
else:
|
||||
token_ids = torch.tensor(request.prompt_token_ids)
|
||||
|
||||
num_external_hit_tokens = self.client.lookup( # type: ignore[union-attr]
|
||||
token_ids)
|
||||
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
num_external_hit_tokens -= 1
|
||||
|
||||
need_to_allocate = num_external_hit_tokens - num_computed_tokens
|
||||
|
||||
logger.info(
|
||||
"Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
|
||||
request.request_id,
|
||||
request.num_tokens,
|
||||
num_external_hit_tokens,
|
||||
need_to_allocate,
|
||||
)
|
||||
|
||||
if need_to_allocate <= 0:
|
||||
return 0, False
|
||||
|
||||
self.load_specs[request.request_id] = LoadSpec(
|
||||
vllm_cached_tokens=num_computed_tokens,
|
||||
mooncake_cached_tokens=num_external_hit_tokens,
|
||||
can_load=False,
|
||||
)
|
||||
|
||||
return need_to_allocate, self.load_async
|
||||
|
||||
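Illustrative numbers (not from this diff) for the block alignment applied in get_num_new_matched_tokens when discard_partial_chunks is enabled: only the full-block prefix of the prompt is looked up in mooncake.
block_size = 128                      # assumed cache_config.block_size
prompt_len = 300
token_block_end = prompt_len // block_size * block_size
print(token_block_end)                # 256 -> the trailing 44 tokens are never looked up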
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after temporary buffer alloc.
|
||||
|
||||
Mark the pending load spec as loadable once the KV cache
|
||||
manager has allocated blocks for this request.
|
||||
"""
|
||||
local_block_ids = []
|
||||
if num_external_tokens > 0:
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
|
||||
self._unfinished_requests[request.request_id] = (request,
|
||||
local_block_ids)
|
||||
self._unfinished_request_ids.add(request.request_id)
|
||||
if request.request_id not in self.load_specs:
|
||||
# No KV tokens from external KV cache, return
|
||||
return
|
||||
|
||||
if num_external_tokens == 0:
|
||||
# No need to load anything
|
||||
self.load_specs[request.request_id].can_load = False
|
||||
return
|
||||
|
||||
assert (
|
||||
num_external_tokens > 0 and num_external_tokens
|
||||
== self.load_specs[request.request_id].mooncake_cached_tokens -
|
||||
self.load_specs[request.request_id].vllm_cached_tokens
|
||||
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
|
||||
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
||||
f" for request {request.request_id}")
|
||||
|
||||
self.load_specs[request.request_id].can_load = True
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""Attach the connector metadata to the request object.
|
||||
|
||||
This function should NOT modify other fields in the scheduler_output
|
||||
except the `kv_connector_metadata` field.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
|
||||
force_skip_save = self.kv_role == "kv_consumer"
|
||||
|
||||
for finished_req_id in scheduler_output.finished_req_ids:
|
||||
self._request_trackers.pop(finished_req_id, None)
|
||||
self._unfinished_requests.pop(finished_req_id, None)
|
||||
self._unfinished_request_ids.discard(finished_req_id)
|
||||
|
||||
meta = MooncakeConnectorMetadata(self._unfinished_request_ids)
|
||||
|
||||
for request in scheduler_output.scheduled_new_reqs:
|
||||
# Right now, we only load KV for new requests
|
||||
load_spec = self.load_specs.pop(request.req_id, None)
|
||||
num_tokens_to_compute = (
|
||||
request.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[request.req_id])
|
||||
request_tracker = RequestTracker.from_new_request(
|
||||
request, num_tokens_to_compute)
|
||||
self._request_trackers[request.req_id] = request_tracker
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else len(
|
||||
request.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
if isinstance(cached_reqs, list) and not force_skip_save:
|
||||
for i, req in enumerate(cached_reqs):
|
||||
request_tracker = self._request_trackers[req.req_id]
|
||||
request_tracker.update(req.new_token_ids, req.new_block_ids)
|
||||
last_chunk_tokens_num = ((len(req.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else
|
||||
len(req.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=None,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
elif not force_skip_save:
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
request_tracker = self._request_trackers[req_id]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
req_tuple = self._unfinished_requests.get(req_id)
|
||||
if req_tuple:
|
||||
request = req_tuple[0]
|
||||
num_current_tokens = len(request_tracker.token_ids)
|
||||
new_token_ids = request.all_token_ids[
|
||||
num_current_tokens:num_current_tokens + num_new_tokens]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Request {req_id} is not in _unfinished_requests, "
|
||||
f"but it is scheduled to be cached")
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
if not new_block_ids:
|
||||
continue
|
||||
request_tracker.update(new_token_ids, new_block_ids)
|
||||
# do not save during the decode phase
|
||||
if len(request_tracker.token_ids) > len(
|
||||
request.prompt_token_ids):
|
||||
continue
|
||||
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else
|
||||
len(request.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=None,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
request_ids = [
|
||||
req.req_id for req in scheduler_output.scheduled_new_reqs
|
||||
]
|
||||
for request_id, (request,
|
||||
block_ids) in self._unfinished_requests.items():
|
||||
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
|
||||
load_spec = self.load_specs.pop(request_id, None)
|
||||
if not load_spec:
|
||||
continue
|
||||
num_tokens_to_compute = load_spec.mooncake_cached_tokens
|
||||
if (num_tokens_to_compute % self._block_size
|
||||
!= 0) and (num_tokens_to_compute
|
||||
== len(request.prompt_token_ids) - 1):
|
||||
num_tokens_to_compute = num_tokens_to_compute + 1
|
||||
request_tracker = RequestTracker(
|
||||
req_id=request_id,
|
||||
token_ids=request.prompt_token_ids[:num_tokens_to_compute].
|
||||
copy(),
|
||||
allocated_block_ids=block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
|
||||
self._request_trackers[request_id] = request_tracker
|
||||
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=None,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Once a request is finished, determine whether request blocks
|
||||
should be freed now or will be sent asynchronously and freed later.
|
||||
"""
|
||||
if self.kv_role == "kv_consumer":
|
||||
return False, None
|
||||
tracker = self._request_trackers.get(request.request_id)
|
||||
if tracker is not None and tracker.num_saved_tokens <= 0:
|
||||
return False, None
|
||||
delay_free_blocks = len(block_ids) > 0
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(block_ids), request.request_id)
|
||||
return delay_free_blocks, None
|
||||
|
||||
|
||||
class MooncakeLookupClient:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
zmq.REQ, # type: ignore[attr-defined]
|
||||
bind=False,
|
||||
)
|
||||
|
||||
def lookup(self, token_ids: torch.Tensor) -> int:
|
||||
request = self.encoder.encode(token_ids)
|
||||
self.socket.send_multipart(request, copy=False)
|
||||
resp = self.socket.recv()
|
||||
result = int.from_bytes(resp, "big")
|
||||
return result
|
||||
|
||||
def close(self):
|
||||
self.socket.close(linger=0)
|
||||
|
||||
|
||||
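A sketch of the lookup wire format shared by the lookup client above and the lookup server below: the server encodes the hit count as a 4-byte big-endian integer and the client decodes it with int.from_bytes.
hit_tokens = 384
payload = hit_tokens.to_bytes(4, "big")              # what MooncakeLookupServer sends
assert int.from_bytes(payload, "big") == hit_tokens  # what MooncakeLookupClient decodes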
class MooncakeLookupServer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mooncake_engine: MooncakeEngine,
|
||||
vllm_config: "VllmConfig",
|
||||
use_layerwise: bool,
|
||||
):
|
||||
self.decoder = MsgpackDecoder(torch.Tensor)
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
zmq.REP, # type: ignore[attr-defined]
|
||||
bind=True,
|
||||
)
|
||||
|
||||
self.mooncake_engine = mooncake_engine
|
||||
self.running = True
|
||||
|
||||
def process_request():
|
||||
while self.running:
|
||||
frames = self.socket.recv_multipart(copy=False)
|
||||
token_ids = self.decoder.decode(frames)
|
||||
result = self.mooncake_engine.lookup_scheduler(
|
||||
token_ids, use_layerwise)
|
||||
response = result.to_bytes(4, "big")
|
||||
self.socket.send(response)
|
||||
|
||||
self.thread = threading.Thread(target=process_request, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def close(self):
|
||||
self.socket.close(linger=0)
|
||||
# TODO: close the thread!
|
||||
38
vllm_npu/distributed/mooncake/transfer_engine.py
Normal file
38
vllm_npu/distributed/mooncake/transfer_engine.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import ipaddress
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from mooncake.engine import TransferEngine # type: ignore
|
||||
|
||||
_global_te = None
|
||||
_global_te_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_global_te(hostname: str, device_name: Optional[str]):
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if isinstance(ip, ipaddress.IPv6Address):
|
||||
raise RuntimeError(
|
||||
"The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
global _global_te
|
||||
if _global_te is None:
|
||||
with _global_te_lock:
|
||||
# Double-Checked Locking
|
||||
if _global_te is None:
|
||||
if TransferEngine is None:
|
||||
raise RuntimeError("mooncake is not available")
|
||||
transfer_engine = TransferEngine()
|
||||
device_name = device_name if device_name is not None else ""
|
||||
ret_value = transfer_engine.initialize(hostname,
|
||||
"P2PHANDSHAKE",
|
||||
"ascend", device_name)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError(
|
||||
f"TransferEngine initialization failed with ret_value: {ret_value}"
|
||||
)
|
||||
_global_te = transfer_engine
|
||||
return _global_te
|
||||
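A hypothetical usage sketch of the singleton above: every Mooncakestore in the same process shares one TransferEngine, so repeated calls return the identical object. This requires the mooncake package and an Ascend environment; the hostname here is made up.
te_a = get_global_te("192.0.2.10", device_name=None)
te_b = get_global_te("192.0.2.10", device_name=None)
assert te_a is te_b                   # one shared TransferEngine per process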
1263
vllm_npu/distributed/mooncake_connector.py
Normal file
1263
vllm_npu/distributed/mooncake_connector.py
Normal file
File diff suppressed because it is too large
1153
vllm_npu/distributed/mooncake_layerwise_connector.py
Normal file
1153
vllm_npu/distributed/mooncake_layerwise_connector.py
Normal file
File diff suppressed because it is too large
196
vllm_npu/distributed/parallel_state.py
Normal file
196
vllm_npu/distributed/parallel_state.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
||||
init_model_parallel_group)
|
||||
|
||||
import vllm_npu.envs as envs_ascend
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
|
||||
# Currently, mc2 ops need their own group coordinator.
|
||||
_MC2: Optional[GroupCoordinator] = None
|
||||
_MLP_TP: Optional[GroupCoordinator] = None
|
||||
_OTP: Optional[GroupCoordinator] = None
|
||||
_LMTP: Optional[GroupCoordinator] = None
|
||||
_P_TP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_mc2_group() -> GroupCoordinator:
|
||||
assert _MC2 is not None, ("mc2 group is not initialized")
|
||||
return _MC2
|
||||
|
||||
|
||||
def get_otp_group() -> GroupCoordinator:
|
||||
assert _OTP is not None, (
|
||||
"output tensor parallel group is not initialized")
|
||||
return _OTP
|
||||
|
||||
|
||||
def get_lmhead_tp_group() -> GroupCoordinator:
|
||||
assert _LMTP is not None, (
|
||||
"lm head tensor parallel group is not initialized")
|
||||
return _LMTP
|
||||
|
||||
|
||||
def get_mlp_tp_group() -> GroupCoordinator:
|
||||
assert _MLP_TP is not None, ("mlp group is not initialized")
|
||||
return _MLP_TP
|
||||
|
||||
|
||||
def get_p_tp_group() -> GroupCoordinator:
|
||||
assert _P_TP is not None, (
|
||||
"distributed prefill tensor parallel group is not initialized")
|
||||
return _P_TP
|
||||
|
||||
|
||||
def model_parallel_initialized():
|
||||
return (_MC2 is not None)
|
||||
|
||||
|
||||
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
||||
if model_parallel_initialized():
|
||||
return
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
backend = torch.distributed.get_backend(get_world_group().device_group)
|
||||
|
||||
# The layout of all ranks: ExternalDP * EP
|
||||
# ExternalDP is the data parallel group that is not part of the model,
|
||||
# every dp rank can generate independently (in verl integration).
|
||||
all_ranks = torch.arange(world_size).reshape(
|
||||
-1, parallel_config.data_parallel_size *
|
||||
parallel_config.tensor_parallel_size)
|
||||
|
||||
pd_tp_ratio = get_ascend_config().pd_tp_ratio
|
||||
pd_head_ratio = get_ascend_config().pd_head_ratio
|
||||
global _P_TP
|
||||
assert _P_TP is None, (
|
||||
"distributed prefill tensor parallel group is already initialized")
|
||||
prefill_tensor_model_parallel_size = pd_tp_ratio
|
||||
# divide alltoall groups
|
||||
if pd_head_ratio > 1 and get_current_vllm_config(
|
||||
).kv_transfer_config.is_kv_producer:
|
||||
num_head_replica = get_ascend_config().num_head_replica
|
||||
remote_tp_size = parallel_config.tensor_parallel_size // pd_tp_ratio
|
||||
if num_head_replica <= 1:
|
||||
group_ranks = all_ranks.view(
|
||||
-1, prefill_tensor_model_parallel_size).unbind(0)
|
||||
else:
|
||||
group_ranks = all_ranks.clone().view(
|
||||
parallel_config.data_parallel_size, -1,
|
||||
num_head_replica) # [DP_size, num_head, num_head_replica]
|
||||
group_ranks = group_ranks.permute(0, 2, 1)
|
||||
group_ranks = group_ranks.reshape(
|
||||
-1,
|
||||
group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
|
||||
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
|
||||
group_ranks = group_ranks.unsqueeze(-1).view(
|
||||
parallel_config.data_parallel_size, num_head_replica, -1,
|
||||
alltoall_group_size
|
||||
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
|
||||
group_ranks = group_ranks.reshape(-1,
|
||||
alltoall_group_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
local_rank = get_world_group().local_rank
|
||||
num = next(
|
||||
(i for i, ranks in enumerate(group_ranks) if local_rank in ranks),
|
||||
None)
|
||||
_P_TP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name=f"p_tp_{num}")
|
||||
|
||||
global _MC2
|
||||
group_ranks = all_ranks.unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
_MC2 = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="mc2")
|
||||
if envs_ascend.vllm_npu_ENABLE_MLP_OPTIMIZE:
|
||||
global _MLP_TP
|
||||
assert _MLP_TP is None, (
|
||||
"mlp tensor model parallel group is already initialized")
|
||||
|
||||
mlp_tp = parallel_config.data_parallel_size
|
||||
|
||||
all_ranks_mlp_head = torch.arange(world_size).reshape(
|
||||
-1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa
|
||||
group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
# message queue broadcaster is only used in tensor model parallel group
|
||||
_MLP_TP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="mlp_tp")
|
||||
|
||||
# If oproj tensor parallel size is set, we will create a group for it.
|
||||
otp_size = get_ascend_config().oproj_tensor_parallel_size
|
||||
if otp_size is not None:
|
||||
group_ranks = []
|
||||
global _OTP
|
||||
num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
|
||||
for i in range(num_oproj_tensor_parallel_groups):
|
||||
ranks = list(range(i * otp_size, (i + 1) * otp_size))
|
||||
group_ranks.append(ranks)
|
||||
_OTP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="otp")
|
||||
|
||||
lmhead_tensor_parallel_size = get_ascend_config(
|
||||
).lmhead_tensor_parallel_size
|
||||
if lmhead_tensor_parallel_size is not None:
|
||||
group_ranks = []
|
||||
global _LMTP
|
||||
num_lmhead_tensor_parallel_groups: int = (world_size //
|
||||
lmhead_tensor_parallel_size)
|
||||
for i in range(num_lmhead_tensor_parallel_groups):
|
||||
ranks = list(
|
||||
range(i * lmhead_tensor_parallel_size,
|
||||
(i + 1) * lmhead_tensor_parallel_size))
|
||||
group_ranks.append(ranks)
|
||||
_LMTP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="lmheadtp")
|
||||
|
||||
|
||||
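A sketch (made-up parallel sizes) of how the MC2 group ranks are derived above: with world_size=8, data_parallel_size=2 and tensor_parallel_size=4, all_ranks has a single row and that row becomes the one MC2 group.
import torch

world_size, dp_size, tp_size = 8, 2, 4   # assumed layout: ExternalDP x (DP * TP)
all_ranks = torch.arange(world_size).reshape(-1, dp_size * tp_size)
mc2_groups = [x.tolist() for x in all_ranks.unbind(0)]
print(mc2_groups)                        # [[0, 1, 2, 3, 4, 5, 6, 7]] -> one MC2 group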
def get_mlp_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return get_mlp_tp_group().world_size
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_rank():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return get_mlp_tp_group().rank_in_group
|
||||
|
||||
|
||||
def destroy_ascend_model_parallel():
|
||||
global _MC2
|
||||
if _MC2:
|
||||
_MC2.destroy()
|
||||
_MC2 = None
|
||||
|
||||
global _MLP_TP
|
||||
if _MLP_TP:
|
||||
_MLP_TP.destroy()
|
||||
_MLP_TP = None
|
||||
|
||||
global _LMTP
|
||||
if _LMTP:
|
||||
_LMTP.destroy()
|
||||
_LMTP = None
|
||||
|
||||
global _OTP
|
||||
if _OTP:
|
||||
_OTP.destroy()
|
||||
_OTP = None
|
||||
|
||||
global _P_TP
|
||||
if _P_TP:
|
||||
_P_TP.destroy()
|
||||
_P_TP = None
|
||||
61
vllm_npu/distributed/utils.py
Normal file
61
vllm_npu/distributed/utils.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm_npu.distributed.parallel_state import get_p_tp_group
|
||||
|
||||
|
||||
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
|
||||
value: torch.TensorType):
|
||||
if pd_tp_ratio <= 1:
|
||||
return None, None
|
||||
elif key is None or value is None:
|
||||
raise ValueError("key or value is None")
|
||||
k_output = alltoall_and_rearrange(pd_tp_ratio, key)
|
||||
v_output = alltoall_and_rearrange(pd_tp_ratio, value)
|
||||
return k_output, v_output
|
||||
|
||||
|
||||
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
|
||||
num_kv_heads = input_tensor.size(1)
|
||||
output_tensor = torch.zeros_like(input_tensor)
|
||||
dist.all_to_all_single(output_tensor,
|
||||
input_tensor,
|
||||
group=get_p_tp_group().device_group)
|
||||
input_tensor = 0
|
||||
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
|
||||
output_tensor = 0
|
||||
return result
|
||||
|
||||
|
||||
def rearrange_output(base_output: torch.Tensor, cut_num: int,
|
||||
num_kv_heads: int):
|
||||
size_0 = base_output.size(0)
|
||||
if size_0 % cut_num != 0:
|
||||
raise ValueError(
|
||||
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
|
||||
)
|
||||
chunk_size = size_0 // cut_num
|
||||
reshaped = base_output.view(cut_num, chunk_size, -1)
|
||||
transposed = reshaped.transpose(0, 1)
|
||||
return transposed.contiguous().view(size_0, num_kv_heads, -1)
|
||||
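A small self-contained check (synthetic tensor) of the interleaving performed by rearrange_output above: the chunks produced by all_to_all along dim 0 are re-ordered so that each token's KV heads from the different source ranks become contiguous again. This assumes the helper is importable from vllm_npu.distributed.utils as defined in this file.
import torch

from vllm_npu.distributed.utils import rearrange_output  # the helper defined above

cut_num, chunk_size, num_kv_heads, head_dim = 2, 3, 2, 4
base = torch.arange(cut_num * chunk_size * num_kv_heads * head_dim,
                    dtype=torch.float32).view(cut_num * chunk_size,
                                              num_kv_heads, head_dim)
out = rearrange_output(base, cut_num, num_kv_heads)
print(out.shape)  # torch.Size([6, 2, 4]); rows from the two chunks are interleaved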
|
||||
|
||||
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
|
||||
data_ptr = tensor.data_ptr()
|
||||
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
||||
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
||||
return tensor[int(offset):]
|
||||
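A quick illustration (synthetic pointer values) of the arithmetic in align_memory: a buffer starting at byte 1000 aligned up to 512 begins at 1024, so the leading partial elements are skipped.
data_ptr, alignment, element_size = 1000, 512, 2   # made-up pointer and dtype size
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset_elems = (aligned_addr - data_ptr) // element_size
print(aligned_addr, offset_elems)     # 1024 12 -> the first 12 elements are dropped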
|
||||
|
||||
def get_transfer_timeout_value():
|
||||
ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
|
||||
if len(ascend_transfer_timeout) > 0:
|
||||
return int(ascend_transfer_timeout)
|
||||
hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT',
|
||||
'20')) # type: ignore
|
||||
hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT',
|
||||
'7')) # type: ignore
|
||||
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
|
||||
3000)
|
||||
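A worked example of the fallback timeout formula above with the default HCCL_RDMA_TIMEOUT and HCCL_RDMA_RETRY_CNT values:
hccl_rdma_timeout = 20                # default HCCL_RDMA_TIMEOUT
hccl_rdma_retry_cnt = 7               # default HCCL_RDMA_RETRY_CNT
timeout_ms = int((4.096 * (2 ** hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)
print(timeout_ms)                     # 33064 -> roughly a 33 s transfer timeout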
183
vllm_npu/envs.py
Normal file
183
vllm_npu/envs.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
# The begin-* and end* here are used by the documentation generator
|
||||
# to extract the used env vars.
|
||||
|
||||
# begin-env-vars-definition
|
||||
|
||||
env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# max compile thread number for package building. Usually, it is set to
|
||||
# the number of CPU cores. If not set, the default value is None, which
|
||||
# means all CPU cores will be used.
|
||||
"MAX_JOBS":
|
||||
lambda: os.getenv("MAX_JOBS", None),
|
||||
# The build type of the package. It can be one of the following values:
|
||||
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
|
||||
"CMAKE_BUILD_TYPE":
|
||||
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
||||
# Whether to compile custom kernels. If not set, the default value is True.
|
||||
# If set to False, the custom kernels will not be compiled. Please note that
|
||||
# the sleep mode feature will be disabled as well if custom kernels are not
|
||||
# compiled.
|
||||
"COMPILE_CUSTOM_KERNELS":
|
||||
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
||||
# The CXX compiler used for compiling the package. If not set, the default
|
||||
# value is None, which means the system default CXX compiler will be used.
|
||||
"CXX_COMPILER":
|
||||
lambda: os.getenv("CXX_COMPILER", None),
|
||||
# The C compiler used for compiling the package. If not set, the default
|
||||
# value is None, which means the system default C compiler will be used.
|
||||
"C_COMPILER":
|
||||
lambda: os.getenv("C_COMPILER", None),
|
||||
# The version of the Ascend chip. If not set, the default value is
|
||||
# ASCEND910B1 (available for the A2 and A3 series). It's used for package building.
|
||||
# Please make sure that the version is correct.
|
||||
"SOC_VERSION":
|
||||
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
||||
# If set, vllm-ascend will print verbose logs during compilation
|
||||
"VERBOSE":
|
||||
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
||||
# The home path for CANN toolkit. If not set, the default value is
|
||||
# /usr/local/Ascend/ascend-toolkit/latest
|
||||
"ASCEND_HOME_PATH":
|
||||
lambda: os.getenv("ASCEND_HOME_PATH", None),
|
||||
# The path for HCCL library, it's used by pyhccl communicator backend. If
|
||||
# not set, the default value is libhccl.so.
|
||||
"HCCL_SO_PATH":
|
||||
lambda: os.environ.get("HCCL_SO_PATH", None),
|
||||
# The version of vllm is installed. This value is used for developers who
|
||||
# installed vllm from source locally. In this case, the version of vllm is
|
||||
# usually changed. For example, if the version of vllm is "0.9.0", but when
|
||||
# it's installed from source, the version of vllm is usually set to "0.9.1".
|
||||
# In this case, developers need to set this value to "0.9.0" to make sure
|
||||
# that the correct package is installed.
|
||||
"VLLM_VERSION":
|
||||
lambda: os.getenv("VLLM_VERSION", None),
|
||||
# Whether to enable the trace recompiles from pytorch.
|
||||
"vllm_npu_TRACE_RECOMPILES":
|
||||
lambda: bool(int(os.getenv("vllm_npu_TRACE_RECOMPILES", '0'))),
|
||||
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
|
||||
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
|
||||
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
|
||||
),
|
||||
# Whether to enable DBO feature for deepseek model.
|
||||
"vllm_npu_ENABLE_DBO":
|
||||
lambda: bool(int(os.getenv("vllm_npu_ENABLE_DBO", '0'))),
|
||||
# Whether to enable the model execute time observe profile. Disable it when
|
||||
# running vllm ascend in production environment.
|
||||
"vllm_npu_MODEL_EXECUTE_TIME_OBSERVE":
|
||||
lambda: bool(int(os.getenv("vllm_npu_MODEL_EXECUTE_TIME_OBSERVE", '0'))
|
||||
),
|
||||
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf
|
||||
# training, the optimized model may not be suitable. In this case, set this
|
||||
# value to False to disable the optimized model.
|
||||
"USE_OPTIMIZED_MODEL":
|
||||
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
||||
# The tolerance of the kv cache size, if the difference between the
|
||||
# actual kv cache size and the cached kv cache size is less than this value,
|
||||
# then the cached kv cache size will be used.
|
||||
"vllm_npu_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
|
||||
lambda: int(
|
||||
os.getenv("vllm_npu_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
|
||||
# Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue.
|
||||
# We'll remove this flag in the future once it's stable enough.
|
||||
"vllm_npu_ENABLE_TOPK_TOPP_OPTIMIZATION":
|
||||
lambda: bool(
|
||||
int(os.getenv("vllm_npu_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))),
|
||||
# `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
|
||||
# used for llmdatadist to build the communication topology for kv cache transfer, it is
|
||||
# a required variable if `LLMDataDistCMgrConnector` is used as kv connector for disaggregated
|
||||
# pd. The rank table can be generated by adopting the script `gen_ranktable.sh`
|
||||
# in vllm_npu's example folder.
|
||||
"DISAGGREGATED_PREFILL_RANK_TABLE_PATH":
|
||||
lambda: os.getenv("DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None),
|
||||
# `LLMDataDistCMgrConnector` required variable. `vllm_npu_LLMDD_RPC_IP` is used as the
|
||||
# rpc communication listening ip, which will be used to receive the agent metadata from the
|
||||
# remote worker.
|
||||
"vllm_npu_LLMDD_RPC_IP":
|
||||
lambda: os.getenv("vllm_npu_LLMDD_RPC_IP", "0.0.0.0"),
|
||||
# `LLMDataDistCMgrConnector` required variable. `vllm_npu_LLMDD_RPC_PORT` is used as the
|
||||
# rpc communication listening port, which will be used to receive the agent metadata from the
|
||||
# remote worker.
|
||||
"vllm_npu_LLMDD_RPC_PORT":
|
||||
lambda: int(os.getenv("vllm_npu_LLMDD_RPC_PORT", 5557)),
|
||||
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
|
||||
# and the mla_pa will be the default path of deepseek decode path.
|
||||
"vllm_npu_MLA_PA":
|
||||
lambda: int(os.getenv("vllm_npu_MLA_PA", 0)),
|
||||
# Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled.
|
||||
# this feature is supported in A2, and eager mode will get better performance.
|
||||
"vllm_npu_ENABLE_MATMUL_ALLREDUCE":
|
||||
lambda: bool(int(os.getenv("vllm_npu_ENABLE_MATMUL_ALLREDUCE", '0'))),
|
||||
# Whether to enable FlashComm optimization when tensor parallel is enabled.
|
||||
# This feature will get better performance when concurrency is large.
|
||||
"vllm_npu_ENABLE_FLASHCOMM1":
|
||||
lambda: bool(int(os.getenv("vllm_npu_ENABLE_FLASHCOMM1", '0'))),
|
||||
# Whether to enable MLP weight prefetch, only used in small concurrency.
|
||||
"vllm_npu_ENABLE_PREFETCH_MLP":
|
||||
lambda: bool(int(os.getenv("vllm_npu_ENABLE_PREFETCH_MLP", '0'))),
|
||||
# buffer size for gate up prefetch
|
||||
"vllm_npu_MLP_GATE_UP_PREFETCH_SIZE":
|
||||
lambda: int(
|
||||
os.getenv("vllm_npu_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
||||
# buffer size for down proj prefetch
|
||||
"vllm_npu_MLP_DOWN_PREFETCH_SIZE":
|
||||
lambda: int(
|
||||
os.getenv("vllm_npu_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
|
||||
# Whether to enable dense model and general optimizations for better performance.
|
||||
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
|
||||
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
|
||||
"vllm_npu_ENABLE_DENSE_OPTIMIZE":
|
||||
lambda: bool(int(os.getenv("vllm_npu_ENABLE_DENSE_OPTIMIZE", '0'))),
|
||||
# Whether to enable mlp optimize when tensor parallel is enabled.
|
||||
# this feature in eager mode will get better performance.
|
||||
"vllm_npu_ENABLE_MLP_OPTIMIZE":
|
||||
lambda: bool(int(os.getenv("vllm_npu_ENABLE_MLP_OPTIMIZE", '0'))),
|
||||
# Determine the number of physical devices in a non-full-use scenario
|
||||
# caused by the initialization of the Mooncake connector.
|
||||
"PHYSICAL_DEVICES":
|
||||
lambda: os.getenv("PHYSICAL_DEVICES", None),
|
||||
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
|
||||
"MSMONITOR_USE_DAEMON":
|
||||
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
|
||||
"vllm_npu_ENABLE_MLAPO":
|
||||
lambda: bool(int(os.getenv("vllm_npu_ENABLE_MLAPO", '0'))),
|
||||
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
|
||||
"vllm_npu_ENABLE_NZ":
|
||||
lambda: int(os.getenv("vllm_npu_ENABLE_NZ", 1)),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
# lazy evaluation of environment variables
|
||||
if name in env_variables:
|
||||
return env_variables[name]()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(env_variables.keys())
|
||||
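The `__getattr__` / `__dir__` pair above implements PEP 562 lazy module attributes: every access re-reads the environment instead of caching it at import time. A minimal usage sketch, assuming this table is exposed as a `vllm_npu.envs` module (the file name is not visible in this hunk):

# Illustrative only; the module path `vllm_npu.envs` is an assumption.
import os

import vllm_npu.envs as envs

os.environ["SOC_VERSION"] = "ASCEND910B1"
print(envs.SOC_VERSION)        # re-evaluated on every access, not cached
print("VERBOSE" in dir(envs))  # __dir__ exposes every known variable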
0
vllm_npu/eplb/__init__.py
Normal file
0
vllm_npu/eplb/__init__.py
Normal file
0
vllm_npu/eplb/adaptor/__init__.py
Normal file
0
vllm_npu/eplb/adaptor/__init__.py
Normal file
44
vllm_npu/eplb/adaptor/abstract_adaptor.py
Normal file
44
vllm_npu/eplb/adaptor/abstract_adaptor.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor.
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class EplbAdaptor():
|
||||
|
||||
def __init__(self, **args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_rank_expert_workload(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_init_expert_map(self, num_moe_layers: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def do_update_expert_map(self, layer_id: Any,
|
||||
updated_expert_map: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def do_update_expert_weight(self, layer_id: Any,
|
||||
local_expert_to_replace: Any,
|
||||
buffer_tensor_id: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
289
vllm_npu/eplb/adaptor/vllm_adaptor.py
Normal file
289
vllm_npu/eplb/adaptor/vllm_adaptor.py
Normal file
@@ -0,0 +1,289 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor.
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.eplb.adaptor.abstract_adaptor import EplbAdaptor
|
||||
|
||||
|
||||
class VllmEplbAdaptor(EplbAdaptor):
|
||||
|
||||
def __init__(self, model, **args):
|
||||
super().__init__(**args)
|
||||
self.model = model
|
||||
self.rank_id = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
self.param_dict = dict(self.model.named_parameters())
|
||||
if self.model.config.model_type == "qwen3_moe":
|
||||
self.num_dense_layers = 0
|
||||
self.global_expert_num = self.model.config.num_experts
|
||||
else:
|
||||
self.num_dense_layers = self.model.config.first_k_dense_replace
|
||||
self.global_expert_num = self.model.config.n_routed_experts
|
||||
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
|
||||
self.init_redundancy_expert = get_ascend_config(
|
||||
).init_redundancy_expert
|
||||
|
||||
        # TODO: init self.expert_weight_names depending on different model types; only deepseek v3 w8a8 and qwen3-moe are supported here
|
||||
if self.model.quant_config is not None:
|
||||
self.expert_weight_names = [
|
||||
"w13_weight", "w2_weight", "w13_weight_scale",
|
||||
"w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
|
||||
]
|
||||
else:
|
||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||
|
||||
self.expert_map_per_layer = dict(
|
||||
) # reference to expert map on device for expert map update
|
||||
self.expert_map_per_layer_cpu = dict(
|
||||
) # copy of expert map on CPU to avoid device synchronize frequently
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
|
||||
self.model.get_expert_map(self.num_dense_layers + layer_idx)
|
||||
|
||||
        # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
|
||||
num_buffer_tensor = torch.where(
|
||||
self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
|
||||
self.buffer_tensor_list: list[list[Any]] = [
|
||||
[] for _ in range(num_buffer_tensor)
|
||||
]
|
||||
self.init_buffer_tensor(num_buffer_tensor)
|
||||
|
||||
self.expert_param_per_layer = dict()
|
||||
self.init_expert_param_per_layer()
|
||||
|
||||
self.log2phy_map_per_layer = dict()
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
|
||||
self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
|
||||
|
||||
self.all_topk_ids = []
|
||||
|
||||
def init_buffer_tensor(self, num_buffer_tensor):
|
||||
for buffer_id in range(num_buffer_tensor):
|
||||
for name in self.expert_weight_names:
|
||||
complete_name = "model.layers." + str(
|
||||
self.num_dense_layers) + ".mlp.experts." + name
|
||||
expert_tensor = self.param_dict[complete_name].data[0]
|
||||
if name in ["w13_weight", "w2_weight"]:
|
||||
expert_tensor = expert_tensor.clone()
|
||||
buffer_tensor = torch.empty_like(expert_tensor)
|
||||
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
|
||||
|
||||
def init_expert_param_per_layer(self):
|
||||
num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \
|
||||
".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
|
||||
for moe_layer_id in range(self.num_moe_layers):
|
||||
layer_idx = self.num_dense_layers + moe_layer_id
|
||||
self.expert_param_per_layer[layer_idx] = list()
|
||||
for local_expert_id in range(num_local_expert):
|
||||
self.expert_param_per_layer[layer_idx].append([
|
||||
self.param_dict["model.layers." + str(layer_idx) +
|
||||
".mlp.experts." +
|
||||
name].data[local_expert_id]
|
||||
for name in self.expert_weight_names
|
||||
])
|
||||
|
||||
def get_rank_expert_workload(self) -> torch.Tensor:
|
||||
self.moe_load = self.model.get_all_moe_loads()
|
||||
return self.moe_load
|
||||
|
||||
def get_init_expert_map(self, num_moe_layers):
|
||||
expert_map = self.model.get_all_expert_map(num_moe_layers)
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
gathered = torch.empty(
|
||||
(world_size, *expert_map.shape), # [W, L, E]
|
||||
dtype=expert_map.dtype,
|
||||
device=expert_map.device)
|
||||
|
||||
dist.all_gather_into_tensor(gathered, expert_map)
|
||||
all_maps = gathered.permute(1, 0, 2)
|
||||
all_expert_maps = all_maps.cpu()
|
||||
|
||||
for layer_idx in range(num_moe_layers):
|
||||
self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \
|
||||
all_expert_maps[layer_idx][self.rank_id]
|
||||
|
||||
return all_expert_maps
|
||||
|
||||
def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
|
||||
|
||||
try:
|
||||
expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(
|
||||
expert_map_path)
|
||||
expert_map_all = self.local2global(expert_map_tensor)
|
||||
except (TypeError, FileNotFoundError, OSError):
|
||||
expert_map_all = self.determine_expert_map_all()
|
||||
|
||||
for layer_idx in range(num_moe_layers):
|
||||
if self.model.config.model_type == "qwen3_moe":
|
||||
self.expert_map_per_layer_cpu[layer_idx] = \
|
||||
expert_map_all[layer_idx][self.rank_id]
|
||||
else:
|
||||
self.expert_map_per_layer_cpu[layer_idx + self.num_dense_layers] = \
|
||||
expert_map_all[layer_idx][self.rank_id]
|
||||
return expert_map_all
|
||||
|
||||
    def _expert_file_to_tensor(self, expert_map_path: str):
        try:
            with open(expert_map_path, "r") as f:
                data = json.load(f)
            layers_num = data["moe_layer_count"]
            gpus_num = data["layer_list"][0]["device_count"]

            tensor_data = []
            for layer in data["layer_list"]:
                device_data = []
                for device in layer["device_list"]:
                    device_data.append(device["device_expert"])
                tensor_data.append(device_data)
            expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
            return expert_map_tensor, layers_num, gpus_num
        except Exception:
            # Log the failure before propagating it to the caller.
            logger.error(f"failed to read expert_map_path: {expert_map_path}")
            raise
|
||||
|
||||
def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
|
||||
if self.rank_id == 0:
|
||||
num_local_experts = expert_maps.max() + 1
|
||||
expert_maps_local = self.global2local(expert_maps,
|
||||
num_local_experts)
|
||||
|
||||
expert_maps_list = expert_maps_local.tolist()
|
||||
record: dict[str, Any] = {
|
||||
"moe_layer_count": len(expert_maps_list),
|
||||
"layer_list": []
|
||||
}
|
||||
|
||||
for layer_idx, layer_data in enumerate(expert_maps_list):
|
||||
layer_record: dict[str, Any] = {
|
||||
"layer_id": layer_idx,
|
||||
"device_count": len(layer_data),
|
||||
"device_list": []
|
||||
}
|
||||
|
||||
for device_idx, experts in enumerate(layer_data):
|
||||
device_record = {
|
||||
"device_id": device_idx,
|
||||
"device_expert": experts
|
||||
}
|
||||
layer_record["device_list"].append(device_record)
|
||||
|
||||
record["layer_list"].append(layer_record)
|
||||
|
||||
with open(expert_map_record_path, "w") as f:
|
||||
json.dump(record, f, indent=4)
|
||||
|
||||
def do_update_expert_map(self, layer_id, updated_expert_map):
|
||||
self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
|
||||
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
|
||||
|
||||
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
|
||||
buffer_tensor_id):
|
||||
for expert_tensor, buffer_tensor in zip(
|
||||
self.expert_param_per_layer[layer_id][local_expert_to_replace],
|
||||
self.buffer_tensor_list[buffer_tensor_id]):
|
||||
expert_tensor.copy_(buffer_tensor)
|
||||
logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
|
||||
|
||||
def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
|
||||
if self.log2phy_map_per_layer[layer_id] is not None:
|
||||
self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
|
||||
|
||||
def global2local(self, placement: torch.Tensor,
|
||||
E_local: int) -> torch.Tensor:
|
||||
|
||||
L, G, _ = placement.shape
|
||||
device = placement.device
|
||||
|
||||
pt_local = torch.full((L, G, E_local),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement >= 0
|
||||
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
|
||||
|
||||
slot_idx = placement[l_idx, g_idx, k_idx]
|
||||
|
||||
pt_local[l_idx, g_idx, slot_idx] = k_idx
|
||||
|
||||
return pt_local
|
||||
|
||||
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
L, G, E_local = placement_local.shape
|
||||
device = placement_local.device
|
||||
|
||||
max_id = torch.max(placement_local)
|
||||
E_global = (max_id + 1).item() if max_id >= 0 else 0
|
||||
|
||||
if E_global == 0:
|
||||
return torch.empty((L, G, 0), dtype=torch.long, device=device)
|
||||
|
||||
placement_global = torch.full((L, G, E_global),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement_local >= 0
|
||||
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
|
||||
gid_idx = placement_local[l_idx, g_idx, slot_idx]
|
||||
|
||||
placement_global[l_idx, g_idx, gid_idx] = slot_idx
|
||||
|
||||
return placement_global
|
||||
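    # Worked example for the two transforms above (illustrative, not original
    # code): with one layer, two ranks and two local slots,
    #   placement_local  = [[[0, 1], [2, 3]]]
    # maps through local2global to
    #   placement_global = [[[0, 1, -1, -1], [-1, -1, 0, 1]]]
    # i.e. entry [layer, rank, global_expert_id] holds that rank's local slot
    # index, and -1 marks experts the rank does not host; global2local inverts
    # the transform.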
|
||||
def determine_expert_map_all(self):
|
||||
if self.world_size == 1:
|
||||
local_ids = torch.arange(self.global_expert_num, dtype=torch.int32)
|
||||
return local_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1)
|
||||
|
||||
local_num_experts = self.global_expert_num // self.world_size
|
||||
|
||||
expert_map_all = torch.full(
|
||||
(self.num_moe_layers, self.world_size, self.global_expert_num),
|
||||
-1,
|
||||
dtype=torch.int32)
|
||||
|
||||
for r in range(self.world_size):
|
||||
if r < self.world_size - 1:
|
||||
start = r * local_num_experts
|
||||
end = (r + 1) * local_num_experts
|
||||
local_count = local_num_experts
|
||||
else:
|
||||
start = r * local_num_experts
|
||||
end = self.global_expert_num
|
||||
local_count = self.global_expert_num - r * local_num_experts
|
||||
|
||||
if r < self.init_redundancy_expert:
|
||||
local_count += 1
|
||||
if end < self.global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
|
||||
self.num_moe_layers, -1)
|
||||
|
||||
return expert_map_all
|
||||
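The default map above simply splits the experts evenly across ranks, the last rank absorbing the remainder, with `init_redundancy_expert` optionally widening each rank's slice by one. A standalone sketch of the plain even split (redundancy ignored; not part of the original file):

import torch


def even_expert_map(num_layers: int, world_size: int,
                    num_experts: int) -> torch.Tensor:
    # expert_map[layer, rank, global_id] = local slot id, or -1 if not hosted.
    per_rank = num_experts // world_size
    expert_map = torch.full((num_layers, world_size, num_experts),
                            -1,
                            dtype=torch.int32)
    for r in range(world_size):
        start = r * per_rank
        end = (r + 1) * per_rank if r < world_size - 1 else num_experts
        expert_map[:, r, start:end] = torch.arange(end - start,
                                                   dtype=torch.int32)
    return expert_map


# 1 layer, 2 ranks, 5 experts -> the last rank takes the extra expert:
# tensor([[[ 0,  1, -1, -1, -1],
#          [-1, -1,  0,  1,  2]]], dtype=torch.int32)
print(even_expert_map(1, 2, 5))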
0
vllm_npu/eplb/core/__init__.py
Normal file
0
vllm_npu/eplb/core/__init__.py
Normal file
134
vllm_npu/eplb/core/eplb_device_transfer_loader.py
Normal file
134
vllm_npu/eplb/core/eplb_device_transfer_loader.py
Normal file
@@ -0,0 +1,134 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from enum import Enum
|
||||
|
||||
import torch.distributed as dist
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
class ExpertWeightUpdateState(Enum):
|
||||
WAITING = 0 # waiting for updated expert_map by EplbWorker
|
||||
READY = 1 # ready for d2d expert weights updating
|
||||
TRANSFERRING = 2 # d2d finished and waiting for updating expert_map into model
|
||||
|
||||
|
||||
class D2DExpertWeightLoader:
|
||||
|
||||
def __init__(self):
|
||||
self.comm_op_list = None
|
||||
self.updated_expert_map = None
|
||||
self.updated_log2phy_map = None
|
||||
self.layer_id = -1 # layer id to be updated
|
||||
self.state = ExpertWeightUpdateState.WAITING
|
||||
self.recv_expert_list = []
|
||||
self.mock_flag = True
|
||||
|
||||
def set_adator(self, eplb_adaptor):
|
||||
self.eplb_adaptor = eplb_adaptor
|
||||
|
||||
def generate_expert_d2d_transfer_task(self, expert_send_info,
|
||||
expert_recv_info, updated_expert_map,
|
||||
layer_id):
|
||||
        # While the current send/recv and weight/expert_map update tasks are not finished, a new d2d task cannot be accepted
|
||||
if self.state != ExpertWeightUpdateState.WAITING:
|
||||
logger.warning_once(
|
||||
"current d2d weight update tasks are on-going, cannot accept new weight update task"
|
||||
)
|
||||
return
|
||||
|
||||
self.updated_expert_map = updated_expert_map
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.comm_op_list = []
|
||||
for send_info in expert_send_info:
|
||||
dst_rank, global_expert_id_to_send = send_info
|
||||
local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
|
||||
layer_id][global_expert_id_to_send].item()
|
||||
for src_tensor in self.eplb_adaptor.expert_param_per_layer[
|
||||
layer_id][local_expert_id]:
|
||||
src_tensor = src_tensor.clone()
|
||||
self.comm_op_list.append(
|
||||
dist.P2POp(dist.isend, src_tensor, dst_rank))
|
||||
|
||||
buffer_tensor_id = 0
|
||||
for recv_info in expert_recv_info:
|
||||
recv_rank, global_expert_id_to_recv = recv_info
|
||||
for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[
|
||||
buffer_tensor_id]:
|
||||
self.comm_op_list.append(
|
||||
dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
|
||||
local_expert_to_replace = self.updated_expert_map[
|
||||
global_expert_id_to_recv].item()
|
||||
self.recv_expert_list.append(
|
||||
(local_expert_to_replace, buffer_tensor_id))
|
||||
buffer_tensor_id += 1
|
||||
|
||||
self.state = ExpertWeightUpdateState.READY
|
||||
|
||||
def set_log2phy_map(self, log2phy_map):
|
||||
self.updated_log2phy_map = log2phy_map
|
||||
|
||||
def asyn_expert_weight_transfer(self, reqs):
|
||||
        # Only after the send/recv tasks have been parsed into self.comm_op_list can the d2d send/recv tasks be launched
|
||||
if self.state != ExpertWeightUpdateState.READY:
|
||||
return
|
||||
|
||||
# set asynchronous stream for d2d expert weight transfer
|
||||
if self.comm_op_list:
|
||||
ret_list = dist.batch_isend_irecv(self.comm_op_list)
|
||||
reqs.extend(ret_list)
|
||||
|
||||
self.state = ExpertWeightUpdateState.TRANSFERRING
|
||||
|
||||
def update_expert_map_and_weight(self, reqs):
|
||||
        # Only after the send/recv tasks have been launched can the expert_map and weights be updated
|
||||
if self.state != ExpertWeightUpdateState.TRANSFERRING:
|
||||
return
|
||||
|
||||
# Waiting for send/recv tasks finish
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
if self.comm_op_list is not None:
|
||||
self.comm_op_list = None
|
||||
|
||||
# update expert_map
|
||||
self.eplb_adaptor.do_update_expert_map(self.layer_id,
|
||||
self.updated_expert_map)
|
||||
|
||||
# update log2phy_map
|
||||
self.eplb_adaptor.do_update_log2phy_map(self.layer_id,
|
||||
self.updated_log2phy_map)
|
||||
|
||||
# update expert weight
|
||||
buffer_tensor_id = 0
|
||||
for recv_expert_info in self.recv_expert_list:
|
||||
local_expert_to_replace, buffer_tensor_id = recv_expert_info
|
||||
self.eplb_adaptor.do_update_expert_weight(self.layer_id,
|
||||
local_expert_to_replace,
|
||||
buffer_tensor_id)
|
||||
|
||||
logger.info(
|
||||
f"[EPLB] finished update expert weight for layer: {self.layer_id}")
|
||||
|
||||
self.recv_expert_list = []
|
||||
self.updated_expert_map = None
|
||||
self.layer_id = -1
|
||||
self.state = ExpertWeightUpdateState.WAITING
|
||||
|
||||
def load_impl(self, old_expert_table, new_expert_table):
|
||||
raise NotImplementedError
|
||||
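# State machine of D2DExpertWeightLoader (illustrative summary, not original code):
#   WAITING -- generate_expert_d2d_transfer_task() --> READY
#       collects batched isend/irecv ops into comm_op_list
#   READY -- asyn_expert_weight_transfer(reqs) --> TRANSFERRING
#       posts the ops via dist.batch_isend_irecv and hands back the requests
#   TRANSFERRING -- update_expert_map_and_weight(reqs) --> WAITING
#       waits on the requests, then swaps expert_map, log2phy_map and weights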
189
vllm_npu/eplb/core/eplb_utils.py
Normal file
189
vllm_npu/eplb/core/eplb_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove eplb utils.
|
||||
import os.path
|
||||
import random
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
def determine_default_expert_map(global_expert_num, world_size, rank_id,
|
||||
global_redundant_expert_num):
|
||||
if world_size == 1:
|
||||
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
|
||||
return (global_expert_num, local_ids)
|
||||
|
||||
local_num_experts = global_expert_num // world_size
|
||||
|
||||
expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)
|
||||
|
||||
if rank_id < world_size - 1:
|
||||
start = rank_id * local_num_experts
|
||||
end = (rank_id + 1) * local_num_experts
|
||||
local_count = local_num_experts
|
||||
else:
|
||||
start = rank_id * local_num_experts
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - rank_id * local_num_experts
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map[start:end] = local_ids
|
||||
|
||||
return (local_count, expert_map)
|
||||
|
||||
|
||||
def generate_log2phy_map(expert_map):
|
||||
num_local_experts = expert_map.max() + 1
|
||||
log2phy_map = expert_map.clone()
|
||||
num_ranks, num_global_expert = log2phy_map.shape
|
||||
|
||||
row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks, \
|
||||
num_global_expert) * num_local_experts
|
||||
log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1]
|
||||
|
||||
for idx in range(num_global_expert):
|
||||
positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0]
|
||||
negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0]
|
||||
num_rank_holding_expert = positive_rank_idx.size(0)
|
||||
|
||||
if num_rank_holding_expert == 0:
|
||||
log2phy_map[:, idx] = torch.full((num_ranks, ),
|
||||
0,
|
||||
dtype=log2phy_map.dtype)
|
||||
|
||||
if num_rank_holding_expert == 1:
|
||||
log2phy_map[negative_rank_idx, idx] = torch.full(
|
||||
(num_ranks - 1, ),
|
||||
log2phy_map[positive_rank_idx, idx].item(),
|
||||
dtype=log2phy_map.dtype)
|
||||
else:
|
||||
try:
|
||||
random_list = [
|
||||
random.choice(log2phy_map[positive_rank_idx, idx])
|
||||
for _ in range(num_ranks - num_rank_holding_expert)
|
||||
]
|
||||
log2phy_map[negative_rank_idx,
|
||||
idx] = torch.tensor(random_list,
|
||||
dtype=log2phy_map.dtype)
|
||||
except Exception as e:
|
||||
logger.error(f"Fail to get log2phy_map: {str(e)}")
|
||||
|
||||
return log2phy_map
|
||||
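# Worked example for generate_log2phy_map (illustrative, not original code):
#   expert_map  = [[ 0,  1, -1, -1],
#                  [-1, -1,  0,  1]]   # 2 ranks, 2 local experts each
# num_local_experts is 2, so rank r's local slot s becomes physical id r*2+s:
#   log2phy_map = [[0, 1, 2, 3],
#                  [0, 1, 2, 3]]
# Every -1 is replaced by the physical id of some rank that does host the
# expert (picked at random when several replicas exist), so every rank can
# translate any logical expert id into a physical one.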
|
||||
|
||||
def determine_default_log2phy_map(global_expert_num, world_size, rank_id):
|
||||
if world_size == 1:
|
||||
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
|
||||
expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
|
||||
log2phy_map_all = generate_log2phy_map(expert_map_all)
|
||||
return log2phy_map_all[rank_id]
|
||||
|
||||
local_num_experts = global_expert_num // world_size
|
||||
|
||||
expert_map_all = torch.full((world_size, global_expert_num),
|
||||
-1,
|
||||
dtype=torch.int32)
|
||||
|
||||
for r in range(world_size):
|
||||
if r < world_size - 1:
|
||||
start = r * local_num_experts
|
||||
end = (r + 1) * local_num_experts
|
||||
local_count = local_num_experts
|
||||
else:
|
||||
start = r * local_num_experts
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - r * local_num_experts
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map_all[r, start:end] = local_ids
|
||||
|
||||
log2phy_map_all = generate_log2phy_map(expert_map_all)
|
||||
|
||||
return log2phy_map_all[rank_id]
|
||||
|
||||
|
||||
class EPLBParamUtils:
|
||||
|
||||
@staticmethod
|
||||
def check_iterations(iterations):
|
||||
if not isinstance(iterations, int):
|
||||
raise TypeError(f"The {iterations} is not int.")
|
||||
if iterations <= 0:
|
||||
            raise ValueError(
                f"The iterations value {iterations} cannot be less than or equal to 0.")
|
||||
if iterations > sys.maxsize:
|
||||
            raise ValueError(
                f"The iterations value {iterations} cannot be larger than {sys.maxsize}.")
|
||||
|
||||
@staticmethod
|
||||
def check_dynamic_eplb(dynamic_eplb):
|
||||
if dynamic_eplb is None:
|
||||
return
|
||||
if not isinstance(dynamic_eplb, bool):
|
||||
raise TypeError("The dynamic_eplb is not bool.")
|
||||
if dynamic_eplb and os.getenv("DYNAMIC_EPLB", "false") != "true":
|
||||
raise ValueError(
|
||||
'Can not enable dynamic_eplb when not export DYNAMIC_EPLB="true".'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
    def check_expert_map_path(expert_map):
        if expert_map is None:
            return
        if not isinstance(expert_map, str):
            raise TypeError("The expert_map is not a str.")
        if not expert_map.strip():
            raise ValueError("The expert_map is empty.")
        _, ext = os.path.splitext(expert_map)
        if ext.lower() != ".json":
            raise TypeError("The expert_map is not a json file.")
        if not os.path.exists(expert_map):
            raise ValueError("The expert_map does not exist.")
        try:
            with open(expert_map, "r", encoding='utf-8') as f:
                f.read()
        except Exception as e:
            raise IOError(
                f"Failed to read expert info from {expert_map}, please check the reading permission of {expert_map}: {e}"
            )
|
||||
|
||||
@staticmethod
|
||||
def check_expert_map_record_path(expert_map_record_path):
|
||||
if expert_map_record_path is None:
|
||||
return
|
||||
if not isinstance(expert_map_record_path, str):
|
||||
raise TypeError("The expert_map_record_path is not str.")
|
||||
if not expert_map_record_path.strip():
|
||||
raise ValueError("The expert_map_record_path is empty.")
|
||||
_, ext = os.path.splitext(expert_map_record_path)
|
||||
if ext.lower() != ".json":
|
||||
raise TypeError("The expert_map_record_path is not json.")
|
||||
if os.getenv("EXPERT_MAP_RECORD", "false") != "true":
|
||||
raise ValueError(
|
||||
'Can not enable expert_map_record_path when not export EXPERT_MAP_RECORD="true".'
|
||||
)
|
||||
try:
|
||||
with open(expert_map_record_path, "w", encoding='utf-8') as f:
|
||||
f.write("")
|
||||
except Exception as e:
|
||||
            raise IOError(
                f"Failed to write expert info to {expert_map_record_path}, please check the writing permission of {expert_map_record_path}: {e}"
|
||||
)
|
||||
440
vllm_npu/eplb/core/eplb_worker.py
Normal file
440
vllm_npu/eplb/core/eplb_worker.py
Normal file
@@ -0,0 +1,440 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from multiprocessing import Process, Queue
|
||||
from typing import Any
|
||||
|
||||
import networkx as nx # type: ignore
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_npu.eplb.core.eplb_utils import generate_log2phy_map
|
||||
from vllm_npu.eplb.core.policy.policy_factory import (DynamicConfig,
|
||||
PolicyFactory)
|
||||
|
||||
|
||||
class EplbWorker:
|
||||
|
||||
def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
|
||||
self.policy_type = policy_type
|
||||
self.policy = PolicyFactory.generate_policy(policy_type,
|
||||
DynamicConfig())
|
||||
self.shared_dict = shared_dict
|
||||
self.old_expert_maps = None
|
||||
self.enable_d2d = enable_d2d
|
||||
self.rank_id = dist.get_rank()
|
||||
|
||||
def do_update(self):
|
||||
        # put data into the queue
        # in process self.policy.generate_policy()
        # get expert table && tensor
|
||||
|
||||
# async stream
|
||||
# D2D
|
||||
# H2D
|
||||
# Get initial expert_map
|
||||
torch.set_num_threads(1)
|
||||
if self.old_expert_maps is None:
|
||||
self.old_expert_maps = self.get_init_expert_maps()
|
||||
if self.old_expert_maps is not None:
|
||||
self.num_local_experts = self.old_expert_maps.max() + 1
|
||||
else:
|
||||
raise ValueError("Failed to get expert_maps from shared_dict.")
|
||||
|
||||
# Get MOE load information
|
||||
load_info = self.fetch_and_sum_load_info()
|
||||
if load_info is None:
|
||||
return
|
||||
|
||||
# Get the updated expert table based on the workload information
|
||||
old_placement = self.global2local(self.old_expert_maps,
|
||||
self.num_local_experts)
|
||||
_, _, new_placement = self.calculate_rebalance_experts(
|
||||
load_info, old_placement)
|
||||
|
||||
if not torch.is_tensor(new_placement):
|
||||
new_placement = torch.tensor(new_placement)
|
||||
self.check_expert_placement(old_placement, new_placement)
|
||||
new_expert_maps = self.local2global(new_placement)
|
||||
self.update_expert_map(new_expert_maps)
|
||||
|
||||
if self.policy_type == 2:
|
||||
update_info = self.compose_expert_update_info_bipartite(
|
||||
new_expert_maps, self.old_expert_maps)
|
||||
else:
|
||||
update_info = self.compose_expert_update_info_greedy(
|
||||
new_expert_maps, self.old_expert_maps)
|
||||
self.old_expert_maps = new_expert_maps
|
||||
logger.info("EPLB Process compute complete")
|
||||
|
||||
packed_update_info = self.pack_update_info(update_info)
|
||||
|
||||
return packed_update_info
|
||||
|
||||
def check_expert_placement(self, old_placement, new_placement):
|
||||
num_layers = old_placement.shape[0]
|
||||
num_ranks = old_placement.shape[1]
|
||||
|
||||
for layer_id in range(num_layers):
|
||||
# check if any logical expert is not placed on any rank
|
||||
if torch.unique(new_placement[layer_id]).numel() < torch.unique(
|
||||
old_placement[layer_id]).numel():
|
||||
logger.error(
|
||||
f"There exists expert not placed on any rank in layer {layer_id}"
|
||||
)
|
||||
new_placement[layer_id] = old_placement[layer_id]
|
||||
continue
|
||||
|
||||
for rank_id in range(num_ranks):
|
||||
new_placement_check = new_placement[layer_id][rank_id]
|
||||
old_placement_check = old_placement[layer_id][rank_id]
|
||||
|
||||
# check if same logical experts are placed on the same NPU
|
||||
if new_placement_check.numel() != torch.unique(
|
||||
new_placement_check).numel():
|
||||
logger.error(
|
||||
f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
|
||||
)
|
||||
new_placement[layer_id] = old_placement[layer_id]
|
||||
break
|
||||
|
||||
# check if there is any experts movement inside one NPU
|
||||
expert_not_move = torch.isin(new_placement_check,
|
||||
old_placement_check)
|
||||
if not torch.equal(new_placement_check[expert_not_move],
|
||||
old_placement_check[expert_not_move]):
|
||||
logger.error(
|
||||
f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
|
||||
)
|
||||
new_placement[layer_id] = old_placement[layer_id]
|
||||
break
|
||||
|
||||
def compose_expert_update_info_bipartite(self, updated_expert_maps_org,
|
||||
current_expert_maps_org):
|
||||
# transform numpy array to torch tensor
|
||||
updated_expert_maps = updated_expert_maps_org.clone()
|
||||
current_expert_maps = current_expert_maps_org.clone()
|
||||
updated_expert_maps = np.array(updated_expert_maps)
|
||||
current_expert_maps = np.array(current_expert_maps)
|
||||
|
||||
num_layers = current_expert_maps.shape[0]
|
||||
|
||||
for layer_id in range(num_layers):
|
||||
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
|
||||
current_expert_maps_this_layer = current_expert_maps[layer_id]
|
||||
updated_expert_maps_this_layer_org = updated_expert_maps_org[
|
||||
layer_id]
|
||||
|
||||
from typing import Any
|
||||
|
||||
expert_send_info_this_layer: dict[Any, Any] = {}
|
||||
expert_recv_info_this_layer: dict[Any, Any] = {}
|
||||
|
||||
# Guard Clause: if there is no expert weight update, avoid subsequent processing
|
||||
if (np.equal(updated_expert_maps_this_layer,
|
||||
current_expert_maps_this_layer)).all():
|
||||
yield (expert_send_info_this_layer,
|
||||
expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer_org, layer_id)
|
||||
|
||||
# Parse expert_ids each rank needs to receive from other ranks
|
||||
dst_rank_indices, experts_to_recv = np.where(
|
||||
(current_expert_maps_this_layer == -1)
|
||||
& (updated_expert_maps_this_layer != -1))
|
||||
|
||||
# record src ranks for potential transfer
|
||||
src_ranks_set = dict()
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
if expert_id not in src_ranks_set:
|
||||
src_ranks_set[expert_id] = np.where(
|
||||
current_expert_maps_this_layer[:, expert_id] != -1)[0]
|
||||
|
||||
# loop until all experts are scheduled
|
||||
while len(dst_rank_indices) > 0:
|
||||
# construct bipartite graph
|
||||
graph_expert_update: nx.Graph = nx.Graph()
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
dst_rank_id = dst_rank_indices[idx].item()
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
# add src ranks
|
||||
src_rank_ids = src_ranks_set[expert_id]
|
||||
graph_expert_update.add_nodes_from(src_rank_ids,
|
||||
bipartite=0)
|
||||
# add dest rank
|
||||
graph_expert_update.add_node(str(dst_rank_id), bipartite=1)
|
||||
# add edges
|
||||
for src_rank_id in src_rank_ids:
|
||||
graph_expert_update.add_edge(src_rank_id,
|
||||
str(dst_rank_id))
|
||||
|
||||
# graph may not be connected
|
||||
connected_components = list(
|
||||
nx.connected_components(graph_expert_update))
|
||||
all_matches = {}
|
||||
# matching in this loop
|
||||
for i, component in enumerate(connected_components):
|
||||
subgraph = graph_expert_update.subgraph(component)
|
||||
component_matching = nx.bipartite.maximum_matching(
|
||||
subgraph)
|
||||
all_matches.update(component_matching)
|
||||
|
||||
for src_rank, dst_rank in all_matches.items():
|
||||
dst_rank = int(dst_rank)
|
||||
assert src_rank != dst_rank
|
||||
if graph_expert_update.nodes[src_rank]['bipartite'] == 0:
|
||||
# currently not scheduled experts in rank dst_rank
|
||||
experts_v = experts_to_recv[np.where(
|
||||
dst_rank_indices == dst_rank)]
|
||||
# src: src_rank, dest: dst_rank, expert: expert_id
|
||||
expert_id = np.intersect1d(
|
||||
experts_v,
|
||||
np.where(current_expert_maps_this_layer[src_rank]
|
||||
!= -1))[0]
|
||||
|
||||
# record send/rcv pairs
|
||||
if src_rank not in expert_send_info_this_layer:
|
||||
expert_send_info_this_layer[src_rank] = []
|
||||
if dst_rank not in expert_recv_info_this_layer:
|
||||
expert_recv_info_this_layer[dst_rank] = []
|
||||
expert_send_info_this_layer[src_rank].append(
|
||||
(dst_rank, expert_id))
|
||||
expert_recv_info_this_layer[dst_rank].append(
|
||||
(src_rank, expert_id))
|
||||
|
||||
remove_index = np.where(
|
||||
np.logical_and(dst_rank_indices == dst_rank,
|
||||
experts_to_recv == expert_id))
|
||||
|
||||
# update
|
||||
dst_rank_indices = np.delete(dst_rank_indices,
|
||||
remove_index)
|
||||
experts_to_recv = np.delete(experts_to_recv,
|
||||
remove_index)
|
||||
|
||||
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer_org, layer_id)
|
||||
|
||||
# TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
|
||||
def compose_expert_update_info_greedy(self, updated_expert_maps,
|
||||
current_expert_maps):
|
||||
num_layers = current_expert_maps.shape[0]
|
||||
for layer_id in range(num_layers):
|
||||
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
|
||||
current_expert_maps_this_layer = current_expert_maps[layer_id]
|
||||
|
||||
expert_send_info_this_layer: dict[Any, Any] = {}
|
||||
expert_recv_info_this_layer: dict[Any, Any] = {}
|
||||
|
||||
# Guard Clause: if there is no expert weight update, avoid subsequent processing
|
||||
if torch.equal(updated_expert_maps_this_layer,
|
||||
current_expert_maps_this_layer):
|
||||
yield (expert_send_info_this_layer,
|
||||
expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer, layer_id)
|
||||
|
||||
# Parse expert_ids each rank needs to receive from other ranks
|
||||
dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \
|
||||
& (updated_expert_maps_this_layer != -1))
|
||||
|
||||
# Parse expert_ids each rank needs to send to other ranks
|
||||
src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \
|
||||
& (updated_expert_maps_this_layer == -1))
|
||||
|
||||
for idx in range(len(dst_rank_indices)):
|
||||
dst_rank_id = dst_rank_indices[idx].item()
|
||||
expert_id = experts_to_recv[idx].item()
|
||||
if dst_rank_id not in expert_recv_info_this_layer:
|
||||
expert_recv_info_this_layer[dst_rank_id] = []
|
||||
|
||||
if not torch.isin(torch.tensor(expert_id),
|
||||
experts_to_send).any():
|
||||
# if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
|
||||
candidate_src_rank_indices = torch.where(
|
||||
current_expert_maps_this_layer[:, expert_id] != -1)[0]
|
||||
else:
|
||||
candidate_src_rank_indices = src_rank_indices[
|
||||
experts_to_send == expert_id]
|
||||
|
||||
# TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node...
|
||||
src_rank_id = candidate_src_rank_indices[0].item()
|
||||
if src_rank_id not in expert_send_info_this_layer:
|
||||
expert_send_info_this_layer[src_rank_id] = []
|
||||
|
||||
expert_send_info_this_layer[src_rank_id].append(
|
||||
(dst_rank_id, expert_id))
|
||||
expert_recv_info_this_layer[dst_rank_id].append(
|
||||
(src_rank_id, expert_id))
|
||||
|
||||
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
|
||||
updated_expert_maps_this_layer, layer_id)
|
||||
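    # Worked example for the greedy composition above (illustrative): for one
    # layer with
    #   current: rank0 -> {0, 1}, rank1 -> {2, 3}
    #   updated: rank0 -> {0, 3}, rank1 -> {1, 2}
    # rank0 must receive expert 3 from rank1 and rank1 must receive expert 1
    # from rank0, so the generator yields
    #   expert_send_info = {0: [(1, 1)], 1: [(0, 3)]}
    #   expert_recv_info = {0: [(1, 3)], 1: [(0, 1)]}
    # together with the updated per-layer expert map and the layer id.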
|
||||
def calculate_rebalance_experts(self, load_info, old_placement):
|
||||
"""
|
||||
Compute `new_map` by calling the `rebalance_experts` method of the policy instance.
|
||||
"""
|
||||
if self.old_expert_maps is None:
|
||||
return False, None, None
|
||||
|
||||
changed, priority, new_map = self.policy.rebalance_experts(
|
||||
old_placement, load_info)
|
||||
return changed, priority, new_map
|
||||
|
||||
def get_init_expert_maps(self):
|
||||
"""
|
||||
Read the initial expert_map from shared_dict.
|
||||
"""
|
||||
return self.shared_dict.get("expert_maps", None)
|
||||
|
||||
def fetch_and_sum_load_info(self):
|
||||
"""
|
||||
Each time the subprocess is awakened, read the latest moe_load
|
||||
(shape: [num_moe_layers, num_experts_per_layer]) from shared_dict.
|
||||
"""
|
||||
return self.shared_dict.get("moe_load", None)
|
||||
|
||||
def update_expert_map(self, expert_maps):
|
||||
|
||||
self.shared_dict["expert_maps"] = expert_maps
|
||||
|
||||
def global2local(self, placement: torch.Tensor,
|
||||
E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
L, G, _ = placement.shape
|
||||
device = placement.device
|
||||
|
||||
pt_local = torch.full((L, G, E_local),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement >= 0
|
||||
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
|
||||
|
||||
slot_idx = placement[l_idx, g_idx, k_idx]
|
||||
|
||||
pt_local[l_idx, g_idx, slot_idx] = k_idx
|
||||
|
||||
return pt_local
|
||||
|
||||
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
L, G, E_local = placement_local.shape
|
||||
device = placement_local.device
|
||||
|
||||
max_id = torch.max(placement_local)
|
||||
E_global = (max_id + 1).item() if max_id >= 0 else 0
|
||||
|
||||
if E_global == 0:
|
||||
return torch.empty((L, G, 0), dtype=torch.long, device=device)
|
||||
|
||||
placement_global = torch.full((L, G, E_global),
|
||||
fill_value=-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
valid = placement_local >= 0
|
||||
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
|
||||
gid_idx = placement_local[l_idx, g_idx, slot_idx]
|
||||
|
||||
placement_global[l_idx, g_idx, gid_idx] = slot_idx
|
||||
|
||||
return placement_global
|
||||
|
||||
def pack_update_info(self, update_info_generator):
|
||||
"""
|
||||
Pack a list of update info tuples for efficient IPC.
|
||||
"""
|
||||
send_all = []
|
||||
recv_all = []
|
||||
maps = []
|
||||
log2phy_all = []
|
||||
layer_ids = []
|
||||
|
||||
for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
|
||||
|
||||
send_info_this_rank = send_info[
|
||||
self.rank_id] if self.rank_id in send_info else []
|
||||
recv_info_this_rank = recv_info[
|
||||
self.rank_id] if self.rank_id in recv_info else []
|
||||
send_all.append(send_info_this_rank)
|
||||
recv_all.append(recv_info_this_rank)
|
||||
|
||||
maps.append(new_expert_map[self.rank_id].numpy().tolist())
|
||||
|
||||
log2phy_map = generate_log2phy_map(new_expert_map)
|
||||
log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist())
|
||||
|
||||
layer_ids.append(layer_id)
|
||||
|
||||
return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))
|
||||
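    # Each packed entry describes one layer from this rank's point of view
    # (illustrative summary):
    #   (send_info,   # [(dst_rank, global_expert_id), ...]
    #    recv_info,   # [(src_rank, global_expert_id), ...]
    #    expert_map,  # list[int], this rank's row of the new expert map
    #    log2phy,     # list[int], this rank's logical -> physical expert ids
    #    layer_id)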
|
||||
|
||||
class EplbProcess:
|
||||
|
||||
def __init__(self,
|
||||
shared_dict,
|
||||
policy_type: int = 0,
|
||||
enable_d2d: bool = True):
|
||||
"""
|
||||
Args:
|
||||
shared_dict: Cross-process shared dict returned by Manager().dict()
|
||||
policy_type: Integer passed to PolicyFactory.generate_policy
|
||||
enable_d2d: Whether to enable D2D loading
|
||||
"""
|
||||
self.shared_dict = shared_dict
|
||||
self.policy_type = policy_type
|
||||
self.enable_d2d = enable_d2d
|
||||
self.planner_q: Queue[Any] = Queue()
|
||||
self.block_update_q: Queue[Any] = Queue(maxsize=1)
|
||||
|
||||
# Create EplbWorker instance
|
||||
self.worker = EplbWorker(self.shared_dict, self.policy_type,
|
||||
self.enable_d2d)
|
||||
|
||||
def worker_process(self, planner_q, block_update_q):
|
||||
"""
|
||||
Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
planner_q.get()
|
||||
|
||||
packed_update_info = self.worker.do_update()
|
||||
|
||||
while True:
|
||||
if not block_update_q.empty():
|
||||
continue
|
||||
block_update_q.put(packed_update_info)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
                logger.warning(f"[EPLB subprocess] Exiting due to error: {e}",
|
||||
exc_info=True)
|
||||
break
|
||||
|
||||
def _launch_process(self):
|
||||
"""
|
||||
Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
|
||||
"""
|
||||
proc = Process(target=self.worker_process,
|
||||
args=(self.planner_q, self.block_update_q),
|
||||
daemon=True)
|
||||
|
||||
proc.start()
|
||||
return proc
|
||||
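# Typical handshake with the subprocess (illustrative summary, not original code):
#   1. EplbProcess(shared_dict) builds an EplbWorker; _launch_process() spawns
#      worker_process as a daemon and returns the Process handle.
#   2. The main process refreshes shared_dict["moe_load"] and puts a token into
#      planner_q to wake the worker.
#   3. worker_process runs EplbWorker.do_update() and puts the packed update
#      info into block_update_q (maxsize=1, so at most one plan is pending).
#   4. The main process consumes the packed info and drives the device-to-device
#      weight movement (e.g. via D2DExpertWeightLoader).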
0
vllm_npu/eplb/core/policy/__init__.py
Normal file
0
vllm_npu/eplb/core/policy/__init__.py
Normal file
42
vllm_npu/eplb/core/policy/policy_abstract.py
Normal file
42
vllm_npu/eplb/core/policy/policy_abstract.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class DynamicConfig:
|
||||
placement_policy = None
|
||||
|
||||
max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host
|
||||
ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed
|
||||
num_die_per_host = 8 # Number of dies on each host machine
|
||||
|
||||
|
||||
class EplbPolicy:
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
"""
|
||||
Pass in the weights and return expert replication and placement under relevant constraints.
|
||||
INPUT:
|
||||
current_expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_workload = expert_table[layer0][rankId][expert_num_i]
|
||||
|
||||
RETURNED: (res, expert_table)
|
||||
res:
|
||||
1 -- table_changed
|
||||
0 -- not_changed
|
||||
|
||||
expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_num_i --- [0, MaxExpertPerRank]
|
||||
expertID = expert_table[layer0][rankId][expert_num_i]
|
||||
array_values:
|
||||
[0, 1, 2, 3, 248]
|
||||
[4, 5, 6, 7, 254]
|
||||
[8, 9, 10, 11, 71]
|
||||
...
|
||||
[252, 253, 254, 255, 0]
|
||||
"""
|
||||
pass
|
||||
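For reference, the smallest possible concrete policy keeps the current placement and reports "not changed"; a sketch only, assuming the abstract classes above are importable from `vllm_npu.eplb.core.policy.policy_abstract`:

# Minimal no-op policy sketch (illustrative, not part of the original file).
from vllm_npu.eplb.core.policy.policy_abstract import DynamicConfig, EplbPolicy


class NoOpPolicy(EplbPolicy):
    """Always keeps the current expert table; res=0 means "not changed"."""

    def __init__(self, config: DynamicConfig):
        super().__init__(config)

    def rebalance_experts(self, current_expert_table, expert_workload):
        # Returns (changed, priority, new_table), matching how EplbWorker
        # unpacks the result of rebalance_experts.
        return 0, None, current_expert_table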
389
vllm_npu/eplb/core/policy/policy_dynamic_ep.py
Normal file
389
vllm_npu/eplb/core/policy/policy_dynamic_ep.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
|
||||
|
||||
class DynamicTable:
|
||||
# workload_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
workload_table = None
|
||||
|
||||
# placement_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
placement_table = None
|
||||
|
||||
|
||||
class DynamicEplb(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@staticmethod
|
||||
def add_redundant(current_expert_table, expert_workload,
|
||||
num_original_expert):
|
||||
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
||||
workload_new = np.zeros((layer_num, num_original_expert))
|
||||
for layer_idx in range(layer_num):
|
||||
workload_dict: dict[int, int] = defaultdict(int)
|
||||
placement_layer = current_expert_table[layer_idx].copy()
|
||||
workload_layer = expert_workload[layer_idx].copy()
|
||||
for npu_idx in range(npu_num):
|
||||
for expert_idx in range(experts_per_npu):
|
||||
workload_dict[placement_layer[npu_idx][
|
||||
expert_idx]] += workload_layer[npu_idx][expert_idx]
|
||||
for expert_idx in range(num_original_expert):
|
||||
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
||||
return workload_new
|
||||
|
||||
@staticmethod
|
||||
# Split hot (high-load) experts into redundant experts
|
||||
def original_compute_balanced_pack_redundancy(origin_weights, card_num,
|
||||
num_redundancy_expert):
|
||||
        # Step 1: Sort the items by weight in descending order
|
||||
# Sort based on the second element (the second value of each tuple)
|
||||
route_expert_num = len(origin_weights)
|
||||
route_expert_redundancy: list[list[int]] = [
|
||||
[] for _ in range(route_expert_num)
|
||||
]
|
||||
for i in range(num_redundancy_expert):
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
||||
kind='stable')[::-1]
|
||||
weights = [origin_weights[idx] for idx in sorted_indices]
|
||||
tmp_raw_weight = weights[0][1] * (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
||||
avg_weight = tmp_raw_weight / (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
weights[0] = (weights[0][0], avg_weight)
|
||||
origin_weights = weights
|
||||
|
||||
# Step 2: Calculate the number of items per box
|
||||
expert_num = route_expert_num + num_redundancy_expert
|
||||
items_per_box = expert_num // card_num # Number of items per box
|
||||
        remaining_items = expert_num % card_num  # Number of items left over after even division
|
||||
|
||||
# Step 3: Initialize card_num boxes with empty lists to store item IDs
|
||||
boxes: list[list[int]] = [[] for _ in range(card_num)]
|
||||
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
|
||||
box_weights = [0] * card_num # To store the total weight of each box
|
||||
box_counts = [0] * card_num # To store the number of items in each box
|
||||
index = 0
|
||||
for i in range(route_expert_num):
|
||||
redundancy_num = len(route_expert_redundancy[i])
|
||||
for _ in range(redundancy_num):
|
||||
cur_weight = 0
|
||||
for item, weight in origin_weights:
|
||||
if item == i:
|
||||
cur_weight = weight
|
||||
|
||||
boxes[index].append(i)
|
||||
boxes_weights[index].append(cur_weight)
|
||||
box_weights[index] += cur_weight
|
||||
box_counts[index] += 1
|
||||
index += 1
|
||||
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
||||
kind='stable')[::-1]
|
||||
origin_weights = [origin_weights[idx] for idx in sorted_indices]
|
||||
# Step 4: Distribute items into boxes based on weight
|
||||
for item_id, weight in origin_weights:
|
||||
# Find the box with the least items but not full
|
||||
min_box_index = -1
|
||||
for i in range(card_num):
|
||||
if item_id in boxes[i]:
|
||||
continue
|
||||
# Only choose boxes that still have space (box_counts[i] < items_per_box)
|
||||
if box_counts[i] < items_per_box or (box_counts[i]
|
||||
== items_per_box
|
||||
and remaining_items > 0):
|
||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
||||
min_box_index]:
|
||||
min_box_index = i
|
||||
|
||||
# Place the item (id) into the selected box
|
||||
boxes[min_box_index].append(item_id)
|
||||
boxes_weights[min_box_index].append(weight)
|
||||
box_weights[min_box_index] += weight
|
||||
box_counts[min_box_index] += 1
|
||||
|
||||
# If there's an imbalance in the remaining items, reduce the "remaining_items" counter
|
||||
if box_counts[min_box_index] == (items_per_box +
|
||||
1) and remaining_items > 0:
|
||||
remaining_items -= 1
|
||||
|
||||
# Step 5: Output each box's contents and total weight
|
||||
result = []
|
||||
for i in range(card_num):
|
||||
result.append({
|
||||
"box_index": i + 1,
|
||||
"items": boxes[i], # List of item IDs in the box
|
||||
"weight": boxes_weights[i],
|
||||
"total_weight": box_weights[i], # Total weight in this box
|
||||
"item_count": box_counts[i] # Number of items in the box
|
||||
})
|
||||
|
||||
return result, boxes
|
||||
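    # Worked example (illustrative): origin_weights = [(0, 9), (1, 5), (2, 4),
    # (3, 2)] with card_num = 2 and num_redundancy_expert = 0 gives 4 items for
    # 2 cards (2 per card).  Greedy placement by descending weight yields
    # boxes = [[0, 3], [1, 2]] with total weights 11 and 9, pairing the hottest
    # expert with the coldest one.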
|
||||
# Split hot (high-load) experts into redundant experts
|
||||
@staticmethod
|
||||
def compute_balanced_pack_redundancy(origin_weights, card_num,
|
||||
num_redundancy_expert):
|
||||
route_expert_num = len(origin_weights)
|
||||
route_expert_redundancy: list[list[int]] = [
|
||||
[] for _ in range(route_expert_num)
|
||||
]
|
||||
for i in range(num_redundancy_expert):
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights],
|
||||
kind='stable')[::-1]
|
||||
weights = [origin_weights[idx] for idx in sorted_indices]
|
||||
tmp_raw_weight = weights[0][1] * (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
|
||||
avg_weight = tmp_raw_weight / (
|
||||
len(route_expert_redundancy[weights[0][0]]) + 1)
|
||||
weights[0] = (weights[0][0], avg_weight)
|
||||
origin_weights = weights
|
||||
|
||||
expert_num = route_expert_num + num_redundancy_expert
|
||||
if card_num == 0:
|
||||
raise RuntimeError("card_num can not be 0.")
|
||||
items_per_box = expert_num // card_num
|
||||
remaining_items = expert_num % card_num
|
||||
|
||||
boxes: list[list[int]] = [[] for _ in range(card_num)]
|
||||
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
|
||||
box_weights = [0] * card_num
|
||||
box_counts = [0] * card_num
|
||||
|
||||
all_weights = np.zeros((expert_num, ), dtype='object')
|
||||
all_weights[:route_expert_num] = origin_weights
|
||||
|
||||
index = route_expert_num
|
||||
for i in range(route_expert_num):
|
||||
redundancy_num = len(route_expert_redundancy[i])
|
||||
for _ in range(redundancy_num):
|
||||
for item, weight in origin_weights:
|
||||
if item == i:
|
||||
all_weights[index] = (item, weight)
|
||||
index += 1
|
||||
|
||||
sorted_indices = np.argsort([t[1] for t in all_weights],
|
||||
kind='stable')[::-1]
|
||||
all_weights = [all_weights[idx] for idx in sorted_indices]
|
||||
for item_id, weight in all_weights:
|
||||
min_box_index = -1
|
||||
for i in range(card_num):
|
||||
if box_counts[i] < items_per_box or (box_counts[i]
|
||||
== items_per_box
|
||||
and remaining_items > 0):
|
||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
||||
min_box_index]:
|
||||
if item_id not in boxes[i]:
|
||||
min_box_index = i
|
||||
|
||||
boxes[min_box_index].append(item_id)
|
||||
boxes_weights[min_box_index].append(weight)
|
||||
box_weights[min_box_index] += weight
|
||||
box_counts[min_box_index] += 1
|
||||
|
||||
if box_counts[min_box_index] == (items_per_box +
|
||||
1) and remaining_items > 0:
|
||||
remaining_items -= 1
|
||||
|
||||
result = []
|
||||
for i in range(card_num):
|
||||
result.append({
|
||||
"box_index": i + 1,
|
||||
"items": boxes[i],
|
||||
"weight": boxes_weights[i],
|
||||
"total_weight": box_weights[i],
|
||||
"item_count": box_counts[i]
|
||||
})
|
||||
|
||||
return result, boxes
|
||||
|
||||
# Scheme without redundant experts
|
||||
@staticmethod
|
||||
def compute_balanced_pack(origin_weights, card_num):
|
||||
sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1]
|
||||
weights = origin_weights[sorted_indices]
|
||||
expert_num = len(weights)
|
||||
if card_num == 0:
|
||||
raise RuntimeError("card_num can not be 0.")
|
||||
items_per_box = expert_num // card_num
|
||||
remaining_items = expert_num % card_num
|
||||
|
||||
boxes: list[list[int]] = [[] for _ in range(card_num)]
|
||||
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
|
||||
box_weights = [0] * card_num
|
||||
box_counts = [0] * card_num
|
||||
|
||||
for item_id, weight in weights:
|
||||
min_box_index = -1
|
||||
for i in range(card_num):
|
||||
if box_counts[i] < items_per_box or (box_counts[i]
|
||||
== items_per_box
|
||||
and remaining_items > 0):
|
||||
if min_box_index == -1 or box_weights[i] < box_weights[
|
||||
min_box_index]:
|
||||
min_box_index = i
|
||||
|
||||
boxes[min_box_index].append(item_id)
|
||||
boxes_weights[min_box_index].append(weight)
|
||||
box_weights[min_box_index] += weight
|
||||
box_counts[min_box_index] += 1
|
||||
|
||||
if box_counts[min_box_index] == (items_per_box +
|
||||
1) and remaining_items > 0:
|
||||
remaining_items -= 1
|
||||
|
||||
result = []
|
||||
for i in range(card_num):
|
||||
result.append({
|
||||
"box_index": i + 1,
|
||||
"items": boxes[i],
|
||||
"weight": boxes_weights[i],
|
||||
"total_weight": box_weights[i],
|
||||
"item_count": box_counts[i]
|
||||
})
|
||||
|
||||
return result, boxes
|
||||
|
||||
@staticmethod
|
||||
def get_redundant_num(npu_num, counts):
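# `counts` comes from np.unique on one layer's placement, so each entry is the
# number of physical replicas of a logical expert; summing (counts - 1) yields
# the total number of redundant replicas.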
|
||||
redundant_num_each_npu: int = int(np.sum(counts - 1))
|
||||
return redundant_num_each_npu
|
||||
|
||||
@staticmethod
|
||||
def calculate_max_heat_per_layer(workload_table, layer_num):
|
||||
max_heat_per_layer: list[float] = []
|
||||
for layer_idx in range(layer_num):
|
||||
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
|
||||
max_heat_per_layer.append(np.max(npu_heats_now))
|
||||
return max_heat_per_layer
|
||||
|
||||
@staticmethod
|
||||
def constraint_expert_local_exchange(current_expert_table,
|
||||
global_deployment):
|
||||
for layer_id in range(len(global_deployment)):
|
||||
for card_id in range(len(global_deployment[layer_id])):
|
||||
current_list = [
|
||||
int(x) for x in current_expert_table[layer_id][card_id]
|
||||
]
|
||||
new_list = [
|
||||
int(x) for x in global_deployment[layer_id][card_id]
|
||||
]
|
||||
num = len(new_list)
|
||||
|
||||
new_index = [-1] * num
|
||||
new_result = [-1] * num
|
||||
remaining_elements = []
|
||||
|
||||
for i in range(num):
|
||||
flag = True
|
||||
for j in range(num):
|
||||
if new_list[i] == current_list[j] and new_index[
|
||||
j] == -1:
|
||||
new_index[j] = 0
|
||||
new_result[j] = current_list[j]
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
remaining_elements.append(new_list[i])
|
||||
|
||||
index = 0
|
||||
for k in range(num):
|
||||
if new_result[k] == -1:
|
||||
new_result[k] = remaining_elements[index]
|
||||
index += 1
|
||||
|
||||
global_deployment[layer_id][card_id] = new_result
|
||||
|
||||
return global_deployment
|
||||
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
|
||||
info = DynamicTable()
|
||||
info.workload_table = np.array(expert_workload)
|
||||
info.placement_table = np.array(current_expert_table)
|
||||
assert info.workload_table is not None
|
||||
layer_num, num_npus, experts_per_npu = info.workload_table.shape
|
||||
assert info.placement_table is not None
|
||||
row = cast(np.ndarray, info.placement_table[0])
|
||||
expert_ids, counts = np.unique(row, return_counts=True)
|
||||
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
||||
num_original_expert = len(expert_ids)
|
||||
layer_workloads = self.add_redundant(info.placement_table,
|
||||
info.workload_table,
|
||||
num_original_expert)
|
||||
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
|
||||
info.workload_table, layer_num)
|
||||
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
||||
|
||||
# Perform load balancing and deploy redundant experts
|
||||
layer_num = layer_workloads.shape[0]
|
||||
expert_num = layer_workloads.shape[1]
|
||||
# Validate the number of experts and NPUs, and ensure the number of redundant experts does not exceed the number of cards
|
||||
if num_original_expert != expert_num:
|
||||
raise ValueError(
|
||||
f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
|
||||
)
|
||||
|
||||
if num_npus <= 0:
|
||||
raise ValueError("the number of NPUs must be greater than 0")
|
||||
|
||||
if num_npus < num_redundancy_expert:
|
||||
raise ValueError(
|
||||
f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}"
|
||||
)
|
||||
|
||||
# Number of experts deployed on each card includes one redundant expert
|
||||
global_deployment: list[list[list[int]]] = [[[]
|
||||
for _ in range(num_npus)]
|
||||
for _ in range(layer_num)]
|
||||
# Iterate to obtain the placement strategy for each layer, taking computational balance into account
|
||||
max_heat_per_layer_after = np.zeros([layer_num])
|
||||
for layer in range(layer_num):
|
||||
# Get the expert IDs and their corresponding workloads for the current layer;
|
||||
# workloads need to be normalized, and one redundant expert is added per card
|
||||
weights = np.zeros((expert_num, ), dtype='object')
|
||||
for expert_id, workload_weight in enumerate(
|
||||
layer_workloads[layer]):
|
||||
weights[expert_id] = (expert_id, workload_weight)
|
||||
|
||||
# Obtain the globally balanced placement strategy for each layer
|
||||
result, layer_deployment = self.original_compute_balanced_pack_redundancy(
|
||||
weights, num_npus, num_redundancy_expert)
|
||||
|
||||
global_deployment[layer] = layer_deployment
|
||||
max_heat_per_layer_after[layer] = max(
|
||||
result, key=lambda x: x['total_weight'])['total_weight']
|
||||
|
||||
new_global_deployment = self.constraint_expert_local_exchange(
|
||||
current_expert_table, global_deployment)
|
||||
# Obtain the priority of each layer
|
||||
layer_changed_ratio = []
|
||||
for layer_idx in range(layer_num):
|
||||
layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] /
|
||||
max_heat_per_layer_before[layer_idx])
|
||||
|
||||
per_layer_priority = np.argsort(layer_changed_ratio)
|
||||
npu_heat_all_after = sum(max_heat_per_layer_after)
|
||||
|
||||
change = 0
|
||||
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
||||
change = 1
|
||||
|
||||
return change, per_layer_priority, np.array(
|
||||
new_global_deployment).tolist()
|
||||
771
vllm_npu/eplb/core/policy/policy_dynamic_ep_v2.py
Normal file
@@ -0,0 +1,771 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# TODO: Remove this policy once https://github.com/vllm-project/vllm/pull/24069 is merged into vLLM.
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DynamicConfig:
|
||||
placement_policy = None
|
||||
|
||||
max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host
|
||||
ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed
|
||||
num_die_per_host = 8 # Number of dies on each host machine
|
||||
|
||||
|
||||
class EplbPolicy:
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
"""
|
||||
Pass in the weights and return expert replication and placement under relevant constraints.
|
||||
INPUT:
|
||||
current_expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_workload = expert_table[layer0][rankId][expert_num_i]
|
||||
|
||||
RETURNED: (res, expert_table)
|
||||
res:
|
||||
1 -- table_changed
|
||||
0 -- not_changed
|
||||
|
||||
expert_table: [layerId, rankId, expert_num_i]
|
||||
expert_num_i --- [0, MaxExpertPerRank]
|
||||
expertID = expert_table[layer0][rankId][expert_num_i]
|
||||
array_values:
|
||||
[0, 1, 2, 3, 248]
|
||||
[4, 5, 6, 7, 254]
|
||||
[8, 9, 10, 11, 71]
|
||||
...
|
||||
[252, 253, 254, 255, 0]
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DynamicTable:
|
||||
# workload_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
workload_table = None
|
||||
|
||||
# placement_table:
|
||||
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position
|
||||
# Size: number of layers * number of GPUs * number of experts per GPU per layer
|
||||
# The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer
|
||||
# For experts that are not available or collected, the value is set to -1
|
||||
placement_table = None
|
||||
|
||||
|
||||
class DynamicEplbV2(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@staticmethod
|
||||
def safe_divide(a, b):
|
||||
if b == 0:
|
||||
print("Division by zero is not allowed")
|
||||
return 0
|
||||
return a / b
|
||||
|
||||
@staticmethod
|
||||
def safe_exact_divide(a, b):
|
||||
if b == 0:
|
||||
print("Division by zero is not allowed")
|
||||
return 0
|
||||
return a // b
|
||||
|
||||
@staticmethod
|
||||
def safe_mod(a, b):
|
||||
if b == 0:
|
||||
print("Division by zero is not allowed")
|
||||
return 0
|
||||
return a % b
|
||||
|
||||
@staticmethod
|
||||
def add_redundant(current_expert_table, expert_workload,
|
||||
num_original_expert):
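# Fold the per-(NPU, slot) workload table back onto logical expert ids for
# each layer, summing the load of all replicas of the same expert.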
|
||||
layer_num, npu_num, experts_per_npu = expert_workload.shape
|
||||
workload_new = np.zeros((layer_num, num_original_expert))
|
||||
for layer_idx in range(layer_num):
|
||||
workload_dict: dict[int, int] = defaultdict(int)
|
||||
placement_layer = current_expert_table[layer_idx].copy()
|
||||
workload_layer = expert_workload[layer_idx].copy()
|
||||
for npu_idx in range(npu_num):
|
||||
for expert_idx in range(experts_per_npu):
|
||||
workload_dict[placement_layer[npu_idx][
|
||||
expert_idx]] += workload_layer[npu_idx][expert_idx]
|
||||
for expert_idx in range(num_original_expert):
|
||||
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
|
||||
return workload_new
|
||||
|
||||
@staticmethod
|
||||
def get_redundant_num(npu_num, counts):
|
||||
redundant_num_each_npu: int = int(np.sum(counts - 1))
|
||||
return redundant_num_each_npu
|
||||
|
||||
@staticmethod
|
||||
def calculate_max_heat_per_layer(workload_table, layer_num):
|
||||
max_heat_per_layer: list[float] = []
|
||||
for layer_idx in range(layer_num):
|
||||
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
|
||||
max_heat_per_layer.append(np.max(npu_heats_now))
|
||||
return max_heat_per_layer
|
||||
|
||||
def calculate_initial_imbalance(self, global_deployment,
|
||||
new_layer_workloads):
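# Per-layer imbalance = max per-device load / average per-device load, where
# each replicated expert's workload is split evenly across its replicas.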
|
||||
|
||||
device_num = global_deployment.shape[1]
|
||||
layer_imbalance = []
|
||||
expert_num = np.zeros_like(new_layer_workloads)
|
||||
for layer_id, layer in enumerate(global_deployment):
|
||||
for device in layer:
|
||||
for expert_id in device:
|
||||
expert_num[layer_id][expert_id] += 1
|
||||
|
||||
for layer_id, layer in enumerate(global_deployment):
|
||||
cur_layer_max_workload = 0
|
||||
total_workload = 0
|
||||
for box in layer:
|
||||
box_workload = 0
|
||||
for expert_id in box:
|
||||
update_workload = self.safe_divide(
|
||||
new_layer_workloads[layer_id][expert_id],
|
||||
expert_num[layer_id][expert_id])
|
||||
box_workload += update_workload
|
||||
total_workload += update_workload
|
||||
if cur_layer_max_workload < box_workload:
|
||||
cur_layer_max_workload = box_workload
|
||||
|
||||
cur_layer_imbalance = self.safe_divide(
|
||||
cur_layer_max_workload,
|
||||
(self.safe_divide(total_workload, device_num)))
|
||||
layer_imbalance.append(cur_layer_imbalance)
|
||||
|
||||
return layer_imbalance
|
||||
|
||||
def compute_redundant_assignments(self, base_experts,
|
||||
num_redundant_experts, num_experts):
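# Iteratively replicate the currently hottest expert: each new replica id is
# recorded in redundant_assignments and the expert's weight is re-averaged
# over its replica count, so the next iteration sees the reduced load.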
|
||||
|
||||
redundant_assignments: list[list[int]] = [[]
|
||||
for _ in range(num_experts)]
|
||||
current_weights = base_experts.copy()
|
||||
|
||||
for i in range(num_redundant_experts):
|
||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
||||
kind='stable')[::-1]
|
||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||
|
||||
target_expert = sorted_weights[0]
|
||||
expert_id, original_weight = target_expert
|
||||
|
||||
current_redundancy = len(redundant_assignments[expert_id])
|
||||
new_avg_weight = self.safe_divide(
|
||||
original_weight * (current_redundancy + 1),
|
||||
(current_redundancy + 2))
|
||||
|
||||
redundant_assignments[expert_id].append(num_experts + i)
|
||||
current_weights[sorted_indices[0]] = (expert_id, new_avg_weight)
|
||||
|
||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
||||
kind='stable')[::-1]
|
||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||
|
||||
return redundant_assignments, sorted_weights
|
||||
|
||||
def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos,
|
||||
num_experts, num_exist_expert,
|
||||
device_assignments, device_counts,
|
||||
expert_from_device,
|
||||
com_between_devices):
|
||||
|
||||
current_weights = np.zeros((num_experts, ), dtype='object')
|
||||
for expert_id, workload_weight in enumerate(layer_workloads):
|
||||
current_weights[expert_id] = (expert_id, workload_weight)
|
||||
|
||||
devices_with_slots = []
|
||||
for device_id, device_rendun_pos in enumerate(rendun_pos):
|
||||
if len(device_rendun_pos) != 0:
|
||||
devices_with_slots.append(device_id)
|
||||
|
||||
while devices_with_slots:
|
||||
sorted_indices = np.argsort([w for _, w in current_weights],
|
||||
kind='stable')[::-1]
|
||||
sorted_weights = [current_weights[i] for i in sorted_indices]
|
||||
|
||||
for index, target_weight in enumerate(sorted_weights):
|
||||
expert_id, original_weight = target_weight
|
||||
if original_weight == -1:
|
||||
print("Error:Redundant expert failure re-occurred")
|
||||
redundancy_successful = True
|
||||
break
|
||||
redundancy_successful = False
|
||||
for cur_device_id in devices_with_slots:
|
||||
if expert_id not in device_assignments[cur_device_id]:
|
||||
pos = rendun_pos[cur_device_id].pop()
|
||||
if len(rendun_pos[cur_device_id]) == 0:
|
||||
devices_with_slots = [
|
||||
device_id for device_id in devices_with_slots
|
||||
if device_id != cur_device_id
|
||||
]
|
||||
device_assignments[cur_device_id][pos] = expert_id
|
||||
device_counts[cur_device_id] += 1
|
||||
communication_box_index = expert_from_device[expert_id]
|
||||
com_between_devices[cur_device_id][
|
||||
communication_box_index] = expert_id
|
||||
new_weight = self.safe_divide(
|
||||
(original_weight * num_exist_expert[expert_id]),
|
||||
(num_exist_expert[expert_id] + 1))
|
||||
sorted_weights[index] = (expert_id, new_weight)
|
||||
num_exist_expert[expert_id] += 1
|
||||
redundancy_successful = True
|
||||
break
|
||||
if redundancy_successful:
|
||||
break
|
||||
|
||||
sorted_indices = np.argsort([id for id, _ in sorted_weights],
|
||||
kind='stable')
|
||||
sorted_weights = [sorted_weights[i][1] for i in sorted_indices]
|
||||
|
||||
return sorted_weights, device_assignments, device_counts, com_between_devices
|
||||
|
||||
@staticmethod
|
||||
def prepare_expert_list(base_experts, redundant_assignments,
|
||||
num_redundant_experts):
|
||||
redundant_expert_list = np.empty(num_redundant_experts, dtype=object)
|
||||
|
||||
index = 0
|
||||
num_experts = len(redundant_assignments)
|
||||
for expert_id in range(num_experts):
|
||||
for _ in redundant_assignments[expert_id]:
|
||||
redundant_expert_list[index] = (expert_id,
|
||||
next(w
|
||||
for eid, w in base_experts
|
||||
if eid == expert_id))
|
||||
index += 1
|
||||
|
||||
sorted_indices = np.argsort([w for _, w in redundant_expert_list],
|
||||
kind='stable')[::-1]
|
||||
return [redundant_expert_list[i] for i in sorted_indices]
|
||||
|
||||
@staticmethod
|
||||
def non_redundant_expert_information(origin_deployment, updated_weights,
|
||||
rendun_pos):
|
||||
|
||||
device_num = len(origin_deployment)
|
||||
num_experts_per_device = origin_deployment.shape[1]
|
||||
device_assignments = [[-1 for _ in range(num_experts_per_device)]
|
||||
for _ in range(device_num)]
|
||||
device_weights = [[0 for _ in range(num_experts_per_device)]
|
||||
for _ in range(device_num)]
|
||||
device_loads = [0] * device_num
|
||||
device_counts = [0] * device_num
|
||||
|
||||
for device_id, device in enumerate(origin_deployment):
|
||||
for index, expert_id in enumerate(device):
|
||||
if index in rendun_pos[device_id]:
|
||||
continue
|
||||
device_assignments[device_id][index] = expert_id
|
||||
cur_weight = next(
|
||||
weight for expert_id_of_weight, weight in updated_weights
|
||||
if expert_id_of_weight == expert_id)
|
||||
device_weights[device_id][index] = cur_weight
|
||||
device_loads[device_id] += cur_weight
|
||||
device_counts[device_id] += 1
|
||||
|
||||
return device_assignments, device_weights, device_loads, device_counts
|
||||
|
||||
def recomputing_initial_weight(self, layer_workloads, device_assignments):
|
||||
num_all_experts = [0] * len(layer_workloads)
|
||||
for device in device_assignments:
|
||||
for expert_id in device:
|
||||
if expert_id != -1:
|
||||
num_all_experts[expert_id] += 1
|
||||
|
||||
cur_layer_workload = []
|
||||
for expert_id, weight in enumerate(layer_workloads):
|
||||
if num_all_experts[expert_id] == 0:
|
||||
cur_layer_workload.append(-1)
|
||||
else:
|
||||
cur_layer_workload.append(
|
||||
self.safe_divide(weight, num_all_experts[expert_id]))
|
||||
|
||||
return cur_layer_workload, num_all_experts
|
||||
|
||||
def distribute_redun_experts(self, layer_workloads, device_assignments,
|
||||
device_weights, device_loads, device_counts,
|
||||
redundant_expert_list, expert_from_device,
|
||||
num_experts, rendun_pos):
|
||||
|
||||
num_devices = len(device_assignments)
|
||||
com_between_devices: list[dict[int,
|
||||
int]] = [{} for _ in range(num_devices)]
|
||||
|
||||
for expert_id, weight in redundant_expert_list:
|
||||
candidate = -1
|
||||
for dev_id in range(num_devices):
|
||||
if len(rendun_pos[dev_id]) == 0:
|
||||
continue
|
||||
if expert_id in device_assignments[dev_id]:
|
||||
continue
|
||||
if candidate == -1 or device_loads[dev_id] < device_loads[
|
||||
candidate]:
|
||||
candidate = dev_id
|
||||
if candidate != -1:
|
||||
pos = rendun_pos[candidate].pop()
|
||||
device_assignments[candidate][pos] = expert_id
|
||||
device_weights[candidate][pos] = weight
|
||||
device_loads[candidate] += weight
|
||||
device_counts[candidate] += 1
|
||||
|
||||
communication_box_index = expert_from_device[expert_id]
|
||||
com_between_devices[candidate][
|
||||
communication_box_index] = expert_id
|
||||
|
||||
if any(sublist for sublist in rendun_pos):
|
||||
cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(
|
||||
layer_workloads, device_assignments)
|
||||
|
||||
update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments(
|
||||
cur_layer_workload, rendun_pos, num_experts, num_exist_expert,
|
||||
device_assignments, device_loads, expert_from_device,
|
||||
com_between_devices)
|
||||
|
||||
device_loads = [0] * len(device_counts)
|
||||
for device_id, device in enumerate(device_assignments):
|
||||
for index, expert_id in enumerate(device):
|
||||
device_weights[device_id][index] = update_workload[
|
||||
expert_id]
|
||||
device_loads[device_id] += update_workload[expert_id]
|
||||
|
||||
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
||||
|
||||
def redundancy_again(self, layer_workloads, origin_weights,
|
||||
origin_deployment, expert_from_device, num_node,
|
||||
is_node_redundant, rendun_pos):
|
||||
|
||||
num_experts = len(origin_weights)
|
||||
if is_node_redundant:
|
||||
num_experts = num_experts * num_node
|
||||
|
||||
num_redundant_experts = 0
|
||||
for rank_empty_pos in rendun_pos:
|
||||
num_redundant_experts += len(rank_empty_pos)
|
||||
|
||||
redundant_assignments, updated_weights = self.compute_redundant_assignments(
|
||||
origin_weights, num_redundant_experts, num_experts)
|
||||
|
||||
redundant_expert_list = self.prepare_expert_list(
|
||||
updated_weights, redundant_assignments, num_redundant_experts)
|
||||
|
||||
device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information(
|
||||
origin_deployment, updated_weights, rendun_pos)
|
||||
|
||||
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts(
|
||||
layer_workloads, device_assignments, device_weights, device_loads,
|
||||
device_counts, redundant_expert_list, expert_from_device,
|
||||
num_experts, rendun_pos)
|
||||
|
||||
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
|
||||
|
||||
@staticmethod
|
||||
def generate_allocation_report(device_assignments, device_weights,
|
||||
device_loads, device_counts):
|
||||
|
||||
report = []
|
||||
max_load = 0.0
|
||||
|
||||
for dev_id in range(len(device_assignments)):
|
||||
current_load = device_loads[dev_id]
|
||||
max_load = max(max_load, current_load)
|
||||
|
||||
report.append({
|
||||
"device_id": dev_id + 1,
|
||||
"assigned_experts": device_assignments[dev_id],
|
||||
"expert_weights": device_weights[dev_id],
|
||||
"total_load": current_load,
|
||||
"expert_count": device_counts[dev_id]
|
||||
})
|
||||
|
||||
return report, max_load
|
||||
|
||||
@staticmethod
|
||||
def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id,
|
||||
next_device_id, cur_layer_result, com_between_devices):
|
||||
|
||||
cur_device_deployment = cur_layer_result[cur_device_id][
|
||||
'assigned_experts']
|
||||
next_device_deployment = cur_layer_result[next_device_id][
|
||||
'assigned_experts']
|
||||
|
||||
cur_device_weight = cur_layer_result[cur_device_id]['expert_weights']
|
||||
next_device_weight = cur_layer_result[next_device_id]['expert_weights']
|
||||
|
||||
cur_expert_id = cur_device_deployment[cur_exchange_index]
|
||||
next_expert_id = next_device_deployment[next_exchange_index]
|
||||
cur_device_deployment[cur_exchange_index] = next_expert_id
|
||||
next_device_deployment[next_exchange_index] = cur_expert_id
|
||||
|
||||
cur_expert_weight = cur_device_weight[cur_exchange_index]
|
||||
next_expert_weight = next_device_weight[next_exchange_index]
|
||||
cur_device_weight[cur_exchange_index] = next_expert_weight
|
||||
next_device_weight[next_exchange_index] = cur_expert_weight
|
||||
|
||||
cur_layer_result[cur_device_id][
|
||||
'total_load'] += next_expert_weight - cur_expert_weight
|
||||
cur_layer_result[next_device_id][
|
||||
'total_load'] += cur_expert_weight - next_expert_weight
|
||||
|
||||
com_between_devices[cur_device_id][next_device_id] = next_expert_id
|
||||
com_between_devices[next_device_id][cur_device_id] = cur_expert_id
|
||||
|
||||
def redundant_expert_deployment(self, layer_workloads, original_deployment,
|
||||
expert_from_device, node_num,
|
||||
is_node_redundant, rendun_pos):
|
||||
device_num, per_device_expert_num = original_deployment.shape
|
||||
route_expert_num = layer_workloads.shape[0]
|
||||
per_node_device_num = self.safe_exact_divide(device_num, node_num)
|
||||
per_node_route_expert_num = per_node_device_num * (
|
||||
per_device_expert_num - 1)
|
||||
|
||||
weights = np.zeros((route_expert_num, ), dtype='object')
|
||||
for expert_id, workload_weight in enumerate(layer_workloads):
|
||||
weights[expert_id] = (expert_id, workload_weight)
|
||||
|
||||
if is_node_redundant:
|
||||
|
||||
device_assignments = []
|
||||
device_weights = []
|
||||
device_loads = []
|
||||
device_counts = []
|
||||
com_between_devices = []
|
||||
|
||||
for node_id in range(node_num):
|
||||
cur_node_weights = weights[node_id *
|
||||
per_node_route_expert_num:(node_id +
|
||||
1) *
|
||||
per_node_route_expert_num]
|
||||
cur_original_deployment = original_deployment[
|
||||
node_id * per_node_device_num:(node_id + 1) *
|
||||
per_node_device_num]
|
||||
|
||||
cur_node_rendun_pos = rendun_pos[node_id *
|
||||
per_node_device_num:(node_id +
|
||||
1) *
|
||||
per_node_device_num]
|
||||
|
||||
cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again(
|
||||
layer_workloads, cur_node_weights, cur_original_deployment,
|
||||
expert_from_device, node_num, is_node_redundant,
|
||||
cur_node_rendun_pos)
|
||||
device_assignments += cur_device_assignments
|
||||
device_weights += cur_device_weights
|
||||
device_loads += cur_device_loads
|
||||
device_counts += cur_device_counts
|
||||
com_between_devices += cur_com_between_devices
|
||||
|
||||
else:
|
||||
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again(
|
||||
layer_workloads, weights, original_deployment,
|
||||
expert_from_device, node_num, is_node_redundant, rendun_pos)
|
||||
report, max_load = self.generate_allocation_report(
|
||||
device_assignments, device_weights, device_loads, device_counts)
|
||||
|
||||
return report, max_load, com_between_devices
|
||||
|
||||
@staticmethod
|
||||
def two_device_exchange_experts(cur_device_result, exchange_device_result,
|
||||
cur_exchanged_expert_id,
|
||||
next_exchanged_expert_id, ave_workload,
|
||||
increment, num_redundancy_expert):
|
||||
|
||||
cur_device_weight = cur_device_result['expert_weights']
|
||||
next_device_weight = exchange_device_result['expert_weights']
|
||||
|
||||
cur_device_expert_id = cur_device_result['assigned_experts']
|
||||
next_device_expert_id = exchange_device_result['assigned_experts']
|
||||
|
||||
cur_device_total_weight = cur_device_result['total_load']
|
||||
next_device_total_weight = exchange_device_result['total_load']
|
||||
max_weight = max(cur_device_total_weight, next_device_total_weight)
|
||||
|
||||
cur_exchange_index = -1
|
||||
next_exchange_index = -1
|
||||
|
||||
for index, weight in enumerate(cur_device_weight):
|
||||
for next_index, next_weight in enumerate(next_device_weight):
|
||||
change_flag = True
|
||||
if (cur_device_expert_id[index] in next_device_expert_id
|
||||
or next_device_expert_id[next_index]
|
||||
in cur_device_expert_id):
|
||||
change_flag = False
|
||||
if (cur_device_expert_id[index] not in cur_exchanged_expert_id
|
||||
) and (next_device_expert_id[next_index]
|
||||
not in next_exchanged_expert_id) and change_flag:
|
||||
|
||||
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight
|
||||
next_total_weight_after_exchange = next_device_total_weight - next_weight + weight
|
||||
exchange_max_weight = max(
|
||||
cur_total_weight_after_exchange,
|
||||
next_total_weight_after_exchange)
|
||||
if exchange_max_weight < max_weight and (
|
||||
max_weight -
|
||||
exchange_max_weight) >= (ave_workload * increment):
|
||||
max_weight = exchange_max_weight
|
||||
cur_exchange_index = index
|
||||
next_exchange_index = next_index
|
||||
|
||||
return cur_exchange_index, next_exchange_index
|
||||
|
||||
def expert_exchange_between_devices(self,
|
||||
ave_workload,
|
||||
increment,
|
||||
cur_layer_result,
|
||||
com_between_devices,
|
||||
num_redundancy_expert,
|
||||
node_idx=0,
|
||||
per_node_device_num=0,
|
||||
is_node_redundant=False):
|
||||
|
||||
if is_node_redundant:
|
||||
cur_devices_result = cur_layer_result[node_idx *
|
||||
per_node_device_num:
|
||||
(node_idx + 1) *
|
||||
per_node_device_num]
|
||||
else:
|
||||
cur_devices_result = cur_layer_result
|
||||
|
||||
devices_total_weight = []
|
||||
for device in cur_devices_result:
|
||||
devices_total_weight.append(
|
||||
(device['total_load'], device['device_id'] - 1))
|
||||
|
||||
exchange_frequency = 100
|
||||
while exchange_frequency > 0:
|
||||
exchange_frequency -= 1
|
||||
devices_total_weight.sort(key=lambda x: x[0])
|
||||
max_weight_device_id = devices_total_weight[-1][1]
|
||||
exchange = False
|
||||
for index in range(0, len(devices_total_weight) - 1):
|
||||
min_weight_device_id = devices_total_weight[index][1]
|
||||
if min_weight_device_id not in com_between_devices[
|
||||
max_weight_device_id]:
|
||||
cur_exchanged_expert_id = list(
|
||||
com_between_devices[max_weight_device_id].values())
|
||||
next_exchanged_expert_id = list(
|
||||
com_between_devices[min_weight_device_id].values())
|
||||
|
||||
cur_exchange_index, next_exchange_index = self.two_device_exchange_experts(
|
||||
cur_layer_result[max_weight_device_id],
|
||||
cur_layer_result[min_weight_device_id],
|
||||
cur_exchanged_expert_id, next_exchanged_expert_id,
|
||||
ave_workload, increment, num_redundancy_expert)
|
||||
|
||||
if cur_exchange_index != -1:
|
||||
self.exchange_expert(cur_exchange_index,
|
||||
next_exchange_index,
|
||||
max_weight_device_id,
|
||||
min_weight_device_id,
|
||||
cur_layer_result,
|
||||
com_between_devices)
|
||||
|
||||
devices_total_weight[-1] = (
|
||||
cur_layer_result[max_weight_device_id]
|
||||
['total_load'], max_weight_device_id)
|
||||
devices_total_weight[index] = (
|
||||
cur_layer_result[min_weight_device_id]
|
||||
['total_load'], min_weight_device_id)
|
||||
exchange = True
|
||||
break
|
||||
|
||||
if not exchange:
|
||||
break
|
||||
|
||||
def exchange_experts(self, layer_result, layer_com_between_devices,
|
||||
num_nodes, device_num, is_node_redundant,
|
||||
ave_workload, increment, num_redundancy_expert,
|
||||
org_deployment):
|
||||
|
||||
global_deployment = []
|
||||
|
||||
if is_node_redundant:
|
||||
per_node_device_num = self.safe_exact_divide(device_num, num_nodes)
|
||||
for node_idx in range(num_nodes):
|
||||
self.expert_exchange_between_devices(
|
||||
ave_workload, increment, layer_result,
|
||||
layer_com_between_devices, num_redundancy_expert, node_idx,
|
||||
per_node_device_num, is_node_redundant)
|
||||
else:
|
||||
self.expert_exchange_between_devices(ave_workload, increment,
|
||||
layer_result,
|
||||
layer_com_between_devices,
|
||||
num_redundancy_expert)
|
||||
|
||||
max_workload = 0
|
||||
for box in layer_result:
|
||||
global_deployment.append(box['assigned_experts'])
|
||||
if max_workload < box['total_load']:
|
||||
max_workload = box['total_load']
|
||||
|
||||
global_deployment = np.array(global_deployment)
|
||||
|
||||
return global_deployment, max_workload
|
||||
|
||||
def count_elements(self, lst):
|
||||
count = 0
|
||||
for item in lst:
|
||||
if isinstance(item, list):
|
||||
count += self.count_elements(item)
|
||||
else:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def constraint_expert_local_exchange(current_expert_table,
|
||||
global_deployment):
|
||||
for layer_id in range(len(global_deployment)):
|
||||
for card_id in range(len(global_deployment[layer_id])):
|
||||
current_list = [
|
||||
int(x) for x in current_expert_table[layer_id][card_id]
|
||||
]
|
||||
new_list = [
|
||||
int(x) for x in global_deployment[layer_id][card_id]
|
||||
]
|
||||
num = len(new_list)
|
||||
|
||||
new_index = [-1] * num
|
||||
new_result = [-1] * num
|
||||
remaining_elements = []
|
||||
|
||||
for i in range(num):
|
||||
flag = True
|
||||
for j in range(num):
|
||||
if new_list[i] == current_list[j] and new_index[
|
||||
j] == -1:
|
||||
new_index[j] = 0
|
||||
new_result[j] = current_list[j]
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
remaining_elements.append(new_list[i])
|
||||
|
||||
index = 0
|
||||
for k in range(num):
|
||||
if new_result[k] == -1:
|
||||
new_result[k] = remaining_elements[index]
|
||||
index += 1
|
||||
|
||||
global_deployment[layer_id][card_id] = new_result
|
||||
|
||||
return global_deployment
|
||||
|
||||
def rebalance_experts(self,
|
||||
current_expert_table,
|
||||
expert_workload,
|
||||
is_node_redundant=False,
|
||||
increment=0.01):
|
||||
info = DynamicTable()
|
||||
info.workload_table = expert_workload.numpy()
|
||||
info.placement_table = current_expert_table.numpy()
|
||||
assert info.workload_table is not None
|
||||
layer_num, num_npus, experts_per_npu = info.workload_table.shape
|
||||
expert_ids, counts = np.unique(info.placement_table[0],
|
||||
return_counts=True)
|
||||
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
|
||||
num_original_expert = len(expert_ids)
|
||||
layer_workloads = self.add_redundant(info.placement_table,
|
||||
info.workload_table,
|
||||
num_original_expert)
|
||||
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
|
||||
info.workload_table, layer_num)
|
||||
npu_heat_all_origin = sum(max_heat_per_layer_before)
|
||||
|
||||
num_node = self.safe_exact_divide(num_npus, 8)
|
||||
layer_num = layer_workloads.shape[0]
|
||||
expert_num = layer_workloads.shape[1]
|
||||
expert_from_device = np.zeros((layer_num, num_original_expert))
|
||||
|
||||
if num_original_expert != expert_num:
|
||||
raise ValueError(
|
||||
f"The number of original experts ({num_original_expert}) must match expert_num ({expert_num})"
|
||||
)
|
||||
|
||||
if num_npus <= 0:
|
||||
raise ValueError("The number of NPUs must be greater than 0")
|
||||
|
||||
if num_npus < num_redundancy_expert:
|
||||
raise ValueError(
|
||||
f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})"
|
||||
)
|
||||
|
||||
global_deployment: list[list[list[int]]] = [[[]
|
||||
for _ in range(num_npus)]
|
||||
for _ in range(layer_num)]
|
||||
layer_initial_imbalance = self.calculate_initial_imbalance(
|
||||
info.placement_table, layer_workloads)
|
||||
max_heat_per_layer_after = np.zeros([layer_num])
|
||||
sum_num = 0
|
||||
for layer in range(layer_num):
|
||||
# print(f"Load imbalance ratio of layer {layer} under the new workload", layer_initial_imbalance[layer])
|
||||
if layer_initial_imbalance[layer] < 1.01:
|
||||
global_deployment[layer] = info.placement_table[layer]
|
||||
continue
|
||||
|
||||
ave_workload = self.safe_divide(np.sum(layer_workloads[layer]),
|
||||
num_npus)
|
||||
|
||||
rendun_pos: list[list[int]] = [[] for _ in range(num_npus)]
|
||||
existing_experts = set()
|
||||
for device_id, device in enumerate(info.placement_table[layer]):
|
||||
for index, expert_id in enumerate(device):
|
||||
if expert_id not in existing_experts:
|
||||
existing_experts.add(expert_id)
|
||||
expert_from_device[layer][expert_id] = device_id
|
||||
else:
|
||||
rendun_pos[device_id].append(index)
|
||||
|
||||
result, max_workload, com_between_devices = self.redundant_expert_deployment(
|
||||
layer_workloads[layer], info.placement_table[layer],
|
||||
expert_from_device[layer], num_node, is_node_redundant,
|
||||
rendun_pos)
|
||||
# print(layer, f"Imbalance Ratio after Redundancy Adjustment:", self.safe_divide(max_workload, ave_workload))
|
||||
|
||||
global_deployment[layer], new_max_workload = self.exchange_experts(
|
||||
result, com_between_devices, num_node, num_npus,
|
||||
is_node_redundant, ave_workload, increment,
|
||||
num_redundancy_expert, info.placement_table[layer])
|
||||
# print(layer, f"Imbalance Ratio after Swap Adjustment:", self.safe_divide(new_max_workload, ave_workload))
|
||||
|
||||
for device_id in range(num_npus):
|
||||
com_between_devices[device_id] = {
|
||||
key: value
|
||||
for key, value in com_between_devices[device_id].items()
|
||||
}
|
||||
sum_num += self.count_elements(com_between_devices[device_id])
|
||||
|
||||
max_heat_per_layer_after[layer] = max(
|
||||
result, key=lambda x: x['total_load'])['total_load']
|
||||
|
||||
layer_changed_ratio = []
|
||||
for layer_idx in range(layer_num):
|
||||
layer_changed_ratio.append(
|
||||
self.safe_divide(max_heat_per_layer_after[layer_idx],
|
||||
max_heat_per_layer_before[layer_idx]))
|
||||
|
||||
per_layer_priority = np.argsort(layer_changed_ratio)
|
||||
npu_heat_all_after = sum(max_heat_per_layer_after)
|
||||
|
||||
change = 0
|
||||
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
|
||||
change = 1
|
||||
|
||||
new_global_deployment = self.constraint_expert_local_exchange(
|
||||
current_expert_table, global_deployment)
|
||||
|
||||
return change, per_layer_priority, np.array(
|
||||
new_global_deployment).tolist()
|
||||
33
vllm_npu/eplb/core/policy/policy_factory.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
|
||||
# TODO: Remove this factory once https://github.com/vllm-project/vllm/pull/24069 is merged into vLLM.
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
from .policy_dynamic_ep import DynamicEplb
|
||||
from .policy_dynamic_ep_v2 import DynamicEplbV2
|
||||
from .policy_flashlb import FlashLB
|
||||
from .policy_random import RandomLoadBalance
|
||||
|
||||
|
||||
class PolicyFactory:
|
||||
|
||||
@staticmethod
|
||||
def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
|
||||
policy = {
|
||||
# Constraint when applying the Dynamic EPLB policy V2:
# if redundant experts exist, only one redundant expert can be placed on each
# NPU and its physical expert index must be 0.

# A greedy d2d (device-to-device) expert weight update composition is applied.
|
||||
0:
|
||||
RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
|
||||
1:
|
||||
DynamicEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load
|
||||
2:
|
||||
DynamicEplbV2, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
|
||||
3:
|
||||
FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment
|
||||
}
|
||||
policy_class = policy.get(policy_type, RandomLoadBalance)
|
||||
policy_instance = policy_class(config)
|
||||
if policy_type == 3:
|
||||
policy_instance.warm_up()
|
||||
return policy_instance
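# Illustrative usage (a minimal sketch; assumes the caller already holds the
# current placement and workload tables -- DynamicEplbV2 expects torch tensors,
# the other policies accept array-likes):
#   config = DynamicConfig()
#   policy = PolicyFactory.generate_policy(2, config)  # -> DynamicEplbV2
#   changed, priority, new_table = policy.rebalance_experts(placement, workload)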
|
||||
651
vllm_npu/eplb/core/policy/policy_flashlb.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# TODO: Remove this policy once https://github.com/vllm-project/vllm/pull/24069 is merged into vLLM.
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numba import njit # type: ignore
|
||||
|
||||
from .policy_abstract import DynamicConfig, EplbPolicy
|
||||
|
||||
numba_logger = logging.getLogger("numba")
|
||||
numba_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@njit
|
||||
def compute_piece_counts(X, P, stage_weights):
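# Greedy slice allocation: distribute the S = P - N extra replica slots one at
# a time, each time to the expert whose extra split gives the largest weighted
# reduction of the per-stage peak unit load.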
|
||||
n_stage, N = X.shape
|
||||
S = P - N
|
||||
pieces = np.ones(N, dtype=np.int32)
|
||||
unit = X / pieces # unit[i, j] = X[i, j] / pieces[j]
|
||||
|
||||
for _ in range(S):
|
||||
deltas = np.zeros(N, dtype=np.float32)
|
||||
for i in range(n_stage):
|
||||
# Find top1 and top2
|
||||
idx1 = -1
|
||||
idx2 = -1
|
||||
val1 = -1.0
|
||||
val2 = -1.0
|
||||
for j in range(N):
|
||||
v = unit[i, j]
|
||||
if v > val1:
|
||||
val2 = val1
|
||||
idx2 = idx1
|
||||
val1 = v
|
||||
idx1 = j
|
||||
elif v > val2:
|
||||
val2 = v
|
||||
idx2 = j
|
||||
|
||||
origin = unit[i, idx1]
|
||||
secv = unit[i, idx2]
|
||||
alt = X[i, idx1] / (pieces[idx1] + 1)
|
||||
delta = origin - (alt if alt > secv else secv)
|
||||
deltas[idx1] += delta * stage_weights[i] if np.any(
|
||||
delta) != 0 else stage_weights[i]
|
||||
|
||||
max_idx = np.argmax(deltas)
|
||||
pieces[max_idx] += 1
|
||||
for i in range(n_stage):
|
||||
unit[i, max_idx] = X[i, max_idx] / pieces[max_idx]
|
||||
|
||||
# Compute max load
|
||||
max_load = 0.0
|
||||
for j in range(N):
|
||||
total = 0.0
|
||||
for i in range(n_stage):
|
||||
total += unit[i, j]
|
||||
if total > max_load:
|
||||
max_load = total
|
||||
|
||||
return pieces
|
||||
|
||||
|
||||
@njit
|
||||
def jsq_placement(X, pieces, M, stage_weights):
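# JSQ-style placement: assign expert slices in descending hotness order, each
# slice to the device whose weighted projected peak load grows the least,
# never placing two replicas of the same expert on one device.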
|
||||
n_stage, N = X.shape
|
||||
total_piece = pieces.sum()
|
||||
num_per_group = total_piece // M
|
||||
|
||||
# 1. Compute unit_hotness
|
||||
unit_hotness = np.empty((n_stage, N), dtype=np.float32)
|
||||
for i in range(N):
|
||||
if pieces[i] > 0:
|
||||
for s in range(n_stage):
|
||||
unit_hotness[s, i] = X[s, i] / pieces[i]
|
||||
else:
|
||||
for s in range(n_stage):
|
||||
unit_hotness[s, i] = 0.0
|
||||
|
||||
# 2. Sort by total hotness
|
||||
scores = np.zeros(N, dtype=np.float32)
|
||||
for i in range(N):
|
||||
for s in range(n_stage):
|
||||
scores[i] += unit_hotness[s, i]
|
||||
idx = np.argsort(-scores)
|
||||
|
||||
# 3. Initialization
|
||||
loads = np.zeros((n_stage, M), dtype=np.float32)
|
||||
dev_phy_exp_n = np.zeros(M, dtype=np.int32)
|
||||
deployment = -np.ones((M, num_per_group), dtype=np.int32)
|
||||
dep_ptr = np.zeros(M, dtype=np.int32)
|
||||
|
||||
# 4. Main loop
|
||||
for t in range(N):
|
||||
i = idx[t]
|
||||
used_device = list()
|
||||
for _ in range(pieces[i]):
|
||||
# 4.1 Construct w vector
|
||||
w = np.empty(n_stage, dtype=np.float32)
|
||||
for s in range(n_stage):
|
||||
w[s] = unit_hotness[s, i]
|
||||
|
||||
# 4.2 Compute stage-level maximum load
|
||||
stage_max = np.empty(n_stage, dtype=np.float32)
|
||||
for s in range(n_stage):
|
||||
max_val = loads[s, 0]
|
||||
for k in range(1, M):
|
||||
if loads[s, k] > max_val:
|
||||
max_val = loads[s, k]
|
||||
stage_max[s] = max_val
|
||||
|
||||
# 4.3 Compute denominator
|
||||
denom = np.empty(n_stage, dtype=np.float32)
|
||||
for s in range(n_stage):
|
||||
sum_tmp = 0.0
|
||||
for j in range(M):
|
||||
sum_tmp += loads[s, j] + w[s]
|
||||
denom[s] = sum_tmp / M + 1e-2
|
||||
|
||||
# 4.4 Find best device j
|
||||
best_j = -1
|
||||
best_val = 1e30
|
||||
for j in range(M):
|
||||
if dev_phy_exp_n[j] >= num_per_group:
|
||||
continue
|
||||
if j in used_device:
|
||||
continue
|
||||
score = 0.0
|
||||
for s in range(n_stage):
|
||||
tmp_sj = loads[s, j] + w[s]
|
||||
numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s]
|
||||
score += stage_weights[s] * (numer_sj / denom[s])
|
||||
if score < best_val:
|
||||
best_val = score
|
||||
best_j = j
|
||||
if best_j == -1:
|
||||
continue
|
||||
|
||||
used_device.append(best_j)
|
||||
|
||||
# 4.5 Update status
|
||||
for s in range(n_stage):
|
||||
loads[s, best_j] += w[s]
|
||||
ptr = dep_ptr[best_j]
|
||||
deployment[best_j, ptr] = i
|
||||
dep_ptr[best_j] += 1
|
||||
dev_phy_exp_n[best_j] += 1
|
||||
|
||||
# Handle remaining -1 values: fill each empty slot with a random expert that is not already placed on that rank
|
||||
for rank in range(M):
|
||||
for col in range(num_per_group):
|
||||
if deployment[rank, col] == -1:
|
||||
# Get the experts already placed on the current rank
|
||||
current_rank_elements = set(deployment[rank, :])
|
||||
# Filter experts from range(N) that are not yet on this rank
|
||||
available = [
|
||||
x for x in range(N) if x not in current_rank_elements
|
||||
]
|
||||
# Randomly select an available element to fill
|
||||
if len(available) > 0:
|
||||
rand_idx = np.random.randint(0, len(available))
|
||||
deployment[rank, col] = available[rand_idx]
|
||||
elif N > 0:
|
||||
# All experts already appear on this rank, so pick any expert at random.
|
||||
deployment[rank, col] = np.random.randint(0, N)
|
||||
|
||||
return deployment
|
||||
|
||||
|
||||
@njit
|
||||
def slice_values(X, pieces):
|
||||
total_len = 0
|
||||
for i in range(X.shape[0]):
|
||||
total_len += pieces[i]
|
||||
result = np.empty(total_len, dtype=np.float32)
|
||||
idx = 0
|
||||
for i in range(X.shape[0]):
|
||||
val = X[i] / pieces[i]
|
||||
for _ in range(pieces[i]):
|
||||
result[idx] = val
|
||||
idx += 1
|
||||
return result
|
||||
|
||||
|
||||
@njit
|
||||
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
|
||||
simulated_deployment, stage_weights):
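# For each group of M slices, binary-search how many of the remaining extra
# replica slices (S) to spend on it: keep = M - S distinct experts are packed,
# and the resulting load slope is compared against the slope still achievable
# by the later (simulated) slices to decide whether to bloat further.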
|
||||
n_stage, N = X.shape
|
||||
num_group = P // M
|
||||
|
||||
X_all = np.zeros(N, dtype=np.float32)
|
||||
for i in range(n_stage):
|
||||
for j in range(N):
|
||||
X_all[j] += X[i, j]
|
||||
|
||||
sort_idx = np.argsort(np.negative(X_all))
|
||||
X_sorted = X[:, sort_idx]
|
||||
|
||||
unit_load = np.empty(N, dtype=np.float32)
|
||||
for j in range(N):
|
||||
unit_load[j] = X_all[j] / simulated_pieces[j]
|
||||
|
||||
flat_deployment = simulated_deployment.reshape(-1)
|
||||
simulated_load = np.zeros(M, dtype=np.float32)
|
||||
for i in range(flat_deployment.shape[0]):
|
||||
simulated_load[i // (flat_deployment.shape[0] //
|
||||
M)] += unit_load[flat_deployment[i]]
|
||||
|
||||
slice_vals = slice_values(X_all, simulated_pieces)
|
||||
sorted_slices = np.sort(slice_vals)[::-1]
|
||||
simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M
|
||||
|
||||
cumulative_slices_used = np.zeros(N, dtype=np.int32)
|
||||
acc = 0
|
||||
for i in range(N):
|
||||
acc += simulated_pieces[sort_idx[i]]
|
||||
cumulative_slices_used[i] = acc
|
||||
|
||||
group_boundary_indices = np.zeros(num_group, dtype=np.int32)
|
||||
for i in range(1, num_group + 1):
|
||||
for j in range(N):
|
||||
if cumulative_slices_used[j] >= i * M:
|
||||
group_boundary_indices[i - 1] = j
|
||||
break
|
||||
|
||||
slices_used_per_group = np.zeros(num_group, dtype=np.int32)
|
||||
slices_used_per_group[0] = group_boundary_indices[0]
|
||||
for i in range(1, num_group):
|
||||
slices_used_per_group[
|
||||
i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
|
||||
slices_used_per_group = M - slices_used_per_group
|
||||
|
||||
loads = np.zeros(M, dtype=np.float32)
|
||||
pieces = np.zeros(N, dtype=np.int32)
|
||||
num_remain_slice = P - N
|
||||
current_idx = 0
|
||||
|
||||
for g in range(num_group):
|
||||
window = X_sorted[:, current_idx:current_idx + 2 * M]
|
||||
low = max(0, current_idx + M - N)
|
||||
high = min(num_remain_slice, M - 1)
|
||||
|
||||
while (high - low) > 1:
|
||||
mid = int((high + low) // 2)
|
||||
keep = M - mid
|
||||
current_group = window[:, :keep]
|
||||
current_pieces = compute_piece_counts(current_group, M,
|
||||
stage_weights)
|
||||
current_pieces = np.maximum(current_pieces, 1)
|
||||
current_slice = slice_values(current_group.sum(0), current_pieces)
|
||||
current_slice_sorted = np.sort(current_slice)
|
||||
current_loads = loads + current_slice_sorted
|
||||
current_max: np.float32 = np.max(current_loads)
|
||||
current_min: np.float32 = np.min(current_loads)
|
||||
current_slope = (current_max - current_min) / M
|
||||
next_slope: np.float32 = np.max(simulated_slopes[current_idx +
|
||||
keep:])
|
||||
|
||||
if abs(current_slope) > abs(next_slope):
|
||||
low = mid
|
||||
else:
|
||||
high = mid
|
||||
|
||||
S = high
|
||||
keep = M - S
|
||||
current_group = window[:, :keep]
|
||||
current_pieces = compute_piece_counts(current_group, M, stage_weights)
|
||||
|
||||
for i in range(keep):
|
||||
pieces[sort_idx[current_idx + i]] = current_pieces[i]
|
||||
|
||||
current_slice = slice_values(current_group.sum(0), current_pieces)
|
||||
current_slice_sorted = np.sort(current_slice)
|
||||
loads += current_slice_sorted
|
||||
loads = np.sort(loads)[::-1]
|
||||
|
||||
current_idx += keep
|
||||
num_remain_slice -= S
|
||||
|
||||
return pieces
|
||||
|
||||
|
||||
@njit
|
||||
def compute_objective(deployment, X, pieces):
|
||||
M, P = deployment.shape
|
||||
loads = np.zeros(M)
|
||||
|
||||
for i in range(M):
|
||||
for j in range(P):
|
||||
expert = deployment[i, j]
|
||||
if pieces[expert] == 0:
|
||||
continue
|
||||
loads[i] += X[expert] / pieces[expert]
|
||||
|
||||
mean_load = np.mean(loads)
|
||||
max_load: np.float32 = np.max(loads)
|
||||
obj = max_load / mean_load
|
||||
return obj, loads
|
||||
|
||||
|
||||
@njit
|
||||
def auto_fix_new_placement(old_placement, new_placement):
|
||||
"""
|
||||
Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both
|
||||
old_placement and new_placement remain in their original positions from old_placement.
|
||||
New elements (unique to new_placement) will fill the remaining empty positions.
|
||||
|
||||
Args:
|
||||
old_placement: Old deployment matrix with shape (num_ranks, num_experts)
|
||||
new_placement: New deployment matrix to be fixed, must have the same shape as old_placement
|
||||
|
||||
Returns:
|
||||
fixed_new: adjusted version of the new_placement matrix
|
||||
"""
|
||||
num_ranks, num_experts = old_placement.shape
|
||||
fixed_new = np.empty_like(new_placement)
|
||||
|
||||
max_expert_old = old_placement.max() if num_experts > 0 else 0
|
||||
max_expert_new = new_placement.max() if num_experts > 0 else 0
|
||||
max_expert = max(max_expert_old, max_expert_new)
|
||||
|
||||
for rank_id in range(num_ranks):
|
||||
old_row = old_placement[rank_id]
|
||||
new_row = new_placement[rank_id]
|
||||
|
||||
index_array = np.full((max_expert + 1, num_experts),
|
||||
-1,
|
||||
dtype=np.int32)
|
||||
count_array = np.zeros(max_expert + 1, dtype=np.int32)
|
||||
|
||||
for idx in range(num_experts):
|
||||
val = old_row[idx]
|
||||
if val >= 0 and val <= max_expert:
|
||||
pos = count_array[val]
|
||||
index_array[val, pos] = idx
|
||||
count_array[val] += 1
|
||||
|
||||
old_counter = np.zeros(max_expert + 1, dtype=np.int32)
|
||||
for idx in range(num_experts):
|
||||
val = old_row[idx]
|
||||
if val >= 0 and val <= max_expert:
|
||||
old_counter[val] += 1
|
||||
|
||||
retain_elements = np.empty(num_experts, dtype=new_placement.dtype)
|
||||
new_elements = np.empty(num_experts, dtype=new_placement.dtype)
|
||||
retain_ptr = 0
|
||||
new_ptr = 0
|
||||
|
||||
for val in new_row:
|
||||
if val >= 0 and val <= max_expert and old_counter[val] > 0:
|
||||
retain_elements[retain_ptr] = val
|
||||
retain_ptr += 1
|
||||
old_counter[val] -= 1
|
||||
else:
|
||||
new_elements[new_ptr] = val
|
||||
new_ptr += 1
|
||||
|
||||
current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype)
|
||||
|
||||
for i in range(retain_ptr):
|
||||
val = retain_elements[i]
|
||||
if val >= 0 and val <= max_expert:
|
||||
pos = count_array[val] - 1
|
||||
if pos >= 0:
|
||||
idx = index_array[val, pos]
|
||||
current_fixed[idx] = val
|
||||
count_array[val] -= 1
|
||||
|
||||
empty_indices = np.empty(num_experts, dtype=np.int32)
|
||||
empty_ptr = 0
|
||||
for idx in range(num_experts):
|
||||
if current_fixed[idx] == -1:
|
||||
empty_indices[empty_ptr] = idx
|
||||
empty_ptr += 1
|
||||
|
||||
for i in range(new_ptr):
|
||||
if i < empty_ptr:
|
||||
current_fixed[empty_indices[i]] = new_elements[i]
|
||||
|
||||
fixed_new[rank_id] = current_fixed
|
||||
|
||||
return fixed_new
|
||||
|
||||
|
||||
class FlashLB(EplbPolicy):
|
||||
|
||||
def __init__(self, config: DynamicConfig):
|
||||
super().__init__(config)
|
||||
self.par_history: Dict[int, float] = {}
|
||||
self.hotness_window: Dict[int, deque[float]] = {}
|
||||
self.max_stage_window = (config.max_stage_window if hasattr(
|
||||
config, "max_stage_window") else 1)
|
||||
self.buffer_expert_layer_num = (
|
||||
config.buffer_expert_layer_num if hasattr(
|
||||
config, "buffer_expert_layer_num") else 58)
|
||||
self.threshold_ratio = (config.threshold_ratio if hasattr(
|
||||
config, "threshold_ratio") else 0)
|
||||
|
||||
def compute_expert_hotness(self, num_of_expert: int,
|
||||
deployment: np.ndarray, rank_load: np.ndarray):
|
||||
hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
|
||||
deployment_flat = deployment.ravel()
|
||||
rank_load_flat = rank_load.ravel()
|
||||
np.add.at(hotness, deployment_flat, rank_load_flat)
|
||||
return hotness
|
||||
|
||||
def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray):
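# Returns the mean, over stages, of (max per-rank load / mean per-rank load),
# i.e. the peak-to-average ratio (PAR) used as the imbalance metric.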
|
||||
n_stage, N = hotness.shape
|
||||
if np.any(deployment < 0):
|
||||
print(f"Invalid deployment with negative values: {deployment}")
|
||||
raise ValueError("Deployment table contains negative values.")
|
||||
counts = np.bincount(deployment.reshape(-1), minlength=N)
|
||||
unit_hotness = np.divide(hotness,
|
||||
counts,
|
||||
out=np.zeros_like(hotness, dtype=float),
|
||||
where=counts != 0)
|
||||
stage_par = np.zeros(n_stage)
|
||||
for i in range(n_stage):
|
||||
stage_load = unit_hotness[i][deployment].sum(-1)
|
||||
stage_par[i] = stage_load.max() / stage_load.mean()
|
||||
return stage_par.mean()
|
||||
|
||||
def group_based_adaptive_bloating(self,
|
||||
X,
|
||||
P,
|
||||
M,
|
||||
stage_weights=None,
|
||||
recorsive=False):
|
||||
n_stage, N = X.shape
|
||||
if stage_weights is None:
|
||||
stage_weights = np.ones(n_stage, dtype=np.float32)
|
||||
|
||||
if recorsive:
|
||||
(
|
||||
simulated_deployment,
|
||||
simulated_pieces,
|
||||
) = self.group_based_adaptive_bloating(X,
|
||||
P,
|
||||
M,
|
||||
stage_weights,
|
||||
recorsive=False)
|
||||
else:
|
||||
simulated_pieces = compute_piece_counts(X, P, stage_weights)
|
||||
simulated_deployment = jsq_placement(X, simulated_pieces, M,
|
||||
stage_weights)
|
||||
|
||||
pieces = group_based_adaptive_bloating_kernel(
|
||||
X.astype(np.float32),
|
||||
P,
|
||||
M,
|
||||
simulated_pieces.astype(np.int32),
|
||||
simulated_deployment.astype(np.int32),
|
||||
stage_weights.astype(np.float32),
|
||||
)
|
||||
|
||||
deployment = jsq_placement(X, pieces, M, stage_weights)
|
||||
|
||||
X_all = X.sum(0)
|
||||
unit_load = np.divide(X_all,
|
||||
pieces,
|
||||
out=np.zeros_like(X_all, dtype=float),
|
||||
where=pieces != 0)
|
||||
load = unit_load[deployment].sum(-1)
|
||||
|
||||
sim_unit_load = X_all / simulated_pieces
|
||||
sim_load = sim_unit_load[simulated_deployment].sum(-1)
|
||||
|
||||
if load.max() > sim_load.max():
|
||||
return simulated_deployment, simulated_pieces
|
||||
return deployment, pieces
|
||||
|
||||
def need_update(self, current_par, layer_id=0):
|
||||
threshold = self.par_history.get(layer_id, 0.0)
|
||||
return current_par >= self.threshold_ratio * threshold
|
||||
|
||||
def compute_stage_weight(self, hotness):
|
||||
n_stage = hotness.shape[0]
|
||||
stage_weights = np.zeros(n_stage)
|
||||
for i in range(n_stage):
|
||||
stage_weights[i] = hotness[i].sum()
|
||||
|
||||
stage_weights = stage_weights / stage_weights.max()
|
||||
return stage_weights
|
||||
|
||||
def rebalance_layer(self, deployment, hotness, layer_id=0):
|
||||
num_rank, expert_per_rank = deployment.shape
|
||||
num_expert = np.unique(deployment.reshape(-1)).shape[0]
|
||||
num_of_redundant_expert = num_rank * expert_per_rank - num_expert
|
||||
|
||||
current_par = self.compute_rank_load(deployment, hotness)
|
||||
|
||||
if not self.need_update(current_par, layer_id):
|
||||
return deployment, current_par, current_par
|
||||
|
||||
stage_weights = self.compute_stage_weight(hotness)
|
||||
new_deployment, _ = self.group_based_adaptive_bloating(
|
||||
hotness,
|
||||
num_expert + num_of_redundant_expert,
|
||||
num_rank,
|
||||
stage_weights,
|
||||
recorsive=False,
|
||||
)
|
||||
if np.any(new_deployment < 0):
|
||||
print(f"{new_deployment=}")
|
||||
new_par = self.compute_rank_load(new_deployment, hotness)
|
||||
|
||||
return new_deployment, new_par, current_par
|
||||
|
||||
def register_hotness(self, deployment, rank_load, num_layer, num_expert):
|
||||
for layer in range(num_layer):
|
||||
if layer not in self.hotness_window:
|
||||
self.hotness_window[layer] = deque(
|
||||
maxlen=self.max_stage_window)
|
||||
hotness = self.compute_expert_hotness(num_expert,
|
||||
deployment[layer],
|
||||
rank_load[layer])
|
||||
self.hotness_window[layer].append(hotness)
|
||||
|
||||
def compress_by_avg_pooling_fast_nd(self, arr, m):
|
||||
n, d = arr.shape
|
||||
idx = (np.arange(n) * m // n)
|
||||
result = np.zeros((m, d))
|
||||
counts = np.zeros((m, 1))
|
||||
np.add.at(result, idx, arr)
|
||||
np.add.at(counts, idx, 1)
|
||||
return result / counts
|
||||
|
||||
def rebalance_experts(self, current_expert_table, expert_workload):
|
||||
current_deployment = np.array(current_expert_table)
|
||||
expert_workload = np.array(expert_workload)
|
||||
expert_workload += 1
|
||||
num_layer = expert_workload.shape[0]
|
||||
num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0]
|
||||
self.register_hotness(current_deployment, expert_workload, num_layer,
|
||||
num_expert)
|
||||
|
||||
new_deployment = current_deployment.copy()
|
||||
|
||||
layers_need_update = np.arange(num_layer)
|
||||
|
||||
new_par = np.zeros(layers_need_update.shape[0])
|
||||
current_par = np.zeros(layers_need_update.shape[0])
|
||||
for i, layer in enumerate(layers_need_update):
|
||||
hotness = np.array(self.hotness_window[layer])
|
||||
if hotness.shape[0] > self.max_stage_window:
|
||||
hotness = self.compress_by_avg_pooling_fast_nd(
|
||||
hotness, self.max_stage_window)
|
||||
|
||||
(
|
||||
new_deployment[layer],
|
||||
new_par[i],
|
||||
current_par[i],
|
||||
) = self.rebalance_layer(current_deployment[layer],
|
||||
hotness,
|
||||
layer_id=layer)
|
||||
|
||||
priority = new_par / current_par
|
||||
priority_idx = np.argsort(priority)
|
||||
priority_idx = priority_idx[priority[priority_idx] <
|
||||
1][:self.buffer_expert_layer_num]
|
||||
|
||||
if np.all(expert_workload == 1):
|
||||
for _, layer in enumerate(layers_need_update):
|
||||
self.hotness_window[layer].pop()
|
||||
return False, np.array([], dtype=int), current_deployment
|
||||
change = len(priority_idx) > 0
|
||||
if change:
|
||||
for idx in priority_idx:
|
||||
self.par_history[layers_need_update[idx]] = new_par[idx]
|
||||
|
||||
layers_need_update = priority_idx
|
||||
deployment = current_deployment
|
||||
for layer in layers_need_update:
|
||||
deployment[layer] = auto_fix_new_placement(
|
||||
current_deployment[layer], new_deployment[layer])
|
||||
|
||||
return change, layers_need_update, deployment
|
||||
|
||||
|
||||
def generate_layered_experts(num_layers=58,
|
||||
layer_shape=(32, 9),
|
||||
expert_min=0,
|
||||
expert_max=255):
|
||||
"""
|
||||
Generate expert deployment matrix meeting the following conditions:
|
||||
- Total of num_layers layers
|
||||
- Each layer has shape layer_shape (32,9)
|
||||
- Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer
|
||||
|
||||
Args:
|
||||
num_layers: Number of layers, default 58
|
||||
layer_shape: Shape of a single layer, default (32,9)
|
||||
expert_min: Minimum expert ID, default 0
|
||||
expert_max: Maximum expert ID, default 255
|
||||
Returns:
|
||||
torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1])
|
||||
"""
|
||||
# 1. Basic parameter calculation
|
||||
expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255)
|
||||
layer_total = layer_shape[0] * layer_shape[
|
||||
1] # Total elements in a single layer: 32*9=288
|
||||
extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32
|
||||
|
||||
# 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts)
|
||||
assert layer_total >= expert_num, (
|
||||
f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, "
|
||||
"cannot cover all experts")
|
||||
|
||||
# 3. Generate layers one by one
|
||||
layers = []
|
||||
for _ in range(num_layers):
|
||||
# 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included)
|
||||
full_experts = torch.arange(expert_min,
|
||||
expert_max + 1,
|
||||
dtype=torch.int64) # shape (256,)
|
||||
|
||||
# 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255)
|
||||
extra_experts = torch.randint(expert_min,
|
||||
expert_max + 1,
|
||||
size=(extra_slots, ),
|
||||
dtype=torch.int64) # shape (32,)
|
||||
|
||||
# 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer)
|
||||
layer_flat = torch.cat([full_experts, extra_experts],
|
||||
dim=0) # shape (288,)
|
||||
# Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues)
|
||||
shuffle_idx = torch.randperm(layer_flat.shape[0])
|
||||
layer_shuffled = layer_flat[shuffle_idx]
|
||||
|
||||
# 3.4 Reshape to layer_shape (32,9)
|
||||
layer = layer_shuffled.reshape(layer_shape)
|
||||
layers.append(layer)
|
||||
|
||||
# 4. Stack all layers to get the final tensor
|
||||
return torch.stack(layers, dim=0) # shape (58,32,9)
|
||||
|
||||
|
||||
def warm_up():
|
||||
exam_config = DynamicConfig()
|
||||
exam_config.ep_worldsize = 32
|
||||
exam_config.num_die_per_host = 16
|
||||
algo = FlashLB(exam_config)
|
||||
# Generate target tensor
|
||||
expert_tensor = generate_layered_experts(num_layers=58,
|
||||
layer_shape=(32, 9))
|
||||
|
||||
algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))
|
||||
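The hotness accumulation is the core bookkeeping step of this policy, so a tiny, self-contained sketch may help; the shapes and values below are illustrative assumptions, not taken from the plugin.

```python
import numpy as np

# Toy example (illustrative shapes): 2 ranks, 3 slots each, 4 logical experts.
# Each slot's load is accumulated onto the expert it hosts.
deployment = np.array([[0, 1, 2], [2, 3, 0]])
rank_load = np.array([[10, 5, 1], [2, 7, 4]])

policy = FlashLB(DynamicConfig())
hotness = policy.compute_expert_hotness(4, deployment, rank_load)
# hotness == [14, 5, 3, 7]: expert 0 appears twice (10 + 4), expert 2 twice (1 + 2).
print(hotness)
```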
30
vllm_npu/eplb/core/policy/policy_random.py
Normal file
@@ -0,0 +1,30 @@
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm, remove this policy.
import copy
import random

from .policy_abstract import DynamicConfig, EplbPolicy

random.seed(42)


class RandomLoadBalance(EplbPolicy):

    def __init__(self, config: DynamicConfig):
        super().__init__(config)

    def rebalance_experts(self, current_expert_table, expert_workload):
        new_table = copy.deepcopy(current_expert_table)
        num_layers = len(current_expert_table)

        for i in range(num_layers):
            # randomly choose two cards
            # indices = random.sample(range(num_card), 2)
            indices = [3, 1]

            # swap redundant experts
            expert_id_to_exchange = new_table[i][indices[0]][-1].clone()
            new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1]
            new_table[i][indices[1]][-1] = expert_id_to_exchange

        return 1, [-i for i in range(num_layers)], new_table
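A minimal sketch of what this placeholder policy does to an expert table; the 1-layer, 4-card, 3-slot shape below is an assumption chosen only for illustration.

```python
import torch

# Illustrative table: 1 layer, 4 cards, 3 slots per card. The policy swaps the
# last (redundant) slot between cards 3 and 1 in every layer.
table = torch.arange(12).reshape(1, 4, 3)
policy = RandomLoadBalance(DynamicConfig())
changed, layers, new_table = policy.rebalance_experts(table, None)
# new_table[0][1][-1] == 11 and new_table[0][3][-1] == 5 after the swap.
```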
209
vllm_npu/eplb/eplb_updator.py
Normal file
@@ -0,0 +1,209 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm, remove this updator.
import numpy
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import logger

from vllm_npu.eplb.core.eplb_utils import EPLBParamUtils
from vllm_npu.eplb.core.eplb_worker import EplbProcess


class EplbUpdator:

    def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
                 process):
        self.ascend_config = ascend_config
        self.init_eplb(self.ascend_config.expert_map_path, process)
        self.eplb_loader = loader
        self.eplb_process = eplb_process
        self.shared_dict = self.eplb_process.shared_dict

    def set_adaptor(self, adaptor):
        self.adaptor = adaptor
        self.num_moe_layers = self.adaptor.num_moe_layers
        self.global_expert_num = self.adaptor.global_expert_num

    def init_eplb(self, expert_map_path, process):
        self.rank_id = dist.get_rank()
        self.num_expert_load_gather = 10
        self.periodic_load_gather = True
        self.num_iterations_eplb_update: torch.int64 = self.ascend_config.num_iterations_eplb_update
        EPLBParamUtils.check_iterations(self.num_iterations_eplb_update)
        self.expert_map_path = expert_map_path
        self.expert_map_record_path = self.ascend_config.expert_map_record_path

        try:
            if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
                self.num_expert_load_gather = self.num_iterations_eplb_update
                self.periodic_load_gather = False
        except Exception:
            self.num_expert_load_gather = self.num_iterations_eplb_update
            self.periodic_load_gather = False

        self.expert_map_initialized = False
        self.gate_eplb = self.ascend_config.gate_eplb

        self.reqs = []
        self.update_info_all = []

        self.cur_iterations: torch.int64 = 0

        self.num_wait_worker_iterations: torch.int64 = self.ascend_config.num_wait_worker_iterations
        EPLBParamUtils.check_iterations(self.num_wait_worker_iterations)

        self.process = process

        logger.info(
            f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")

    def update_iteration(self):
        self.cur_iterations += 1
        if self.cur_iterations == (self.num_iterations_eplb_update +
                                   self.num_wait_worker_iterations +
                                   self.num_moe_layers):
            if self.expert_map_record_path is not None:
                self.adaptor._export_tensor_to_file(
                    self.shared_dict["expert_maps"],
                    self.expert_map_record_path)

            self.adaptor.model.clear_all_moe_loads()
            if not self.gate_eplb:
                self.cur_iterations = 0

    def get_update_info_flag(self):
        return self.cur_iterations == (self.num_iterations_eplb_update +
                                       self.num_wait_worker_iterations - 1)

    def wakeup_eplb_worker_flag(self):
        return self.cur_iterations == (self.num_iterations_eplb_update - 1)

    def update_expert_weight_flag(self):
        weight_update_counter = self.cur_iterations - (
            self.num_iterations_eplb_update + self.num_wait_worker_iterations)
        return (weight_update_counter >= 0
                and weight_update_counter < self.num_moe_layers)

    def get_init_expert_map(self):
        try:
            if not self.expert_map_initialized:
                self.shared_dict[
                    "expert_maps"] = self.adaptor.get_init_expert_map_from_file(
                        self.num_moe_layers, self.expert_map_path)
                self.expert_map_initialized = True
        except Exception as e:
            logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}",
                           exc_info=True)

    def wakeup_eplb_worker(self):
        self.eplb_process.planner_q.put(1)

    def forward_before(self):
        if self.update_expert_weight_flag():
            (expert_send_info, expert_recv_info, updated_expert_map,
             log2phy_map, layer_id) = self.update_info_all.pop(0)
            log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
            self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
            updated_expert_map_this_rank = torch.from_numpy(
                numpy.array(updated_expert_map))
            self.eplb_loader.generate_expert_d2d_transfer_task(
                expert_send_info, expert_recv_info,
                updated_expert_map_this_rank,
                layer_id + self.adaptor.num_dense_layers)

            # set asynchronous stream for d2d expert weight update
            self.reqs = []
            self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

    def take_update_info_from_eplb_process(self):
        # One batch after the eplb process is triggered, fetch the update info it produced.
        if self.get_update_info_flag():
            self.update_info_all = self.eplb_process.block_update_q.get()

    def forward_end(self):
        if self.wakeup_eplb_worker_flag():
            self.compute_and_set_moe_load(is_clear=True)
            self.wakeup_eplb_worker()

        if self.update_expert_weight_flag(
        ) and self.expert_map_record_path is None:
            self.eplb_loader.update_expert_map_and_weight(self.reqs)

        self.update_iteration()

    def compute_and_set_moe_load(self, is_clear=False):
        local_load = self.adaptor.get_rank_expert_workload()

        self._gather_buffer = None
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
            self.device = local_load.device
            if self._gather_buffer is None:
                shape = (self.world_size, *local_load.shape)
                self._gather_buffer = torch.empty(shape,
                                                  dtype=local_load.dtype,
                                                  device=self.device)

            dist.all_gather_into_tensor(self._gather_buffer, local_load)

            moe_load = self._gather_buffer.permute(1, 0, 2)
            self.shared_dict["moe_load"] = moe_load.cpu()
            logger.debug(
                f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
            )
        else:
            moe_load = local_load.unsqueeze(1)
            self.shared_dict["moe_load"] = moe_load.cpu()
            logger.debug(
                f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
            )
        return moe_load

    def warm_up_eplb(self):

        self.get_init_expert_map()
        self.compute_and_set_moe_load()

        src_tensor = torch.empty((1, ), device=self.device)
        self_rank = dist.get_rank()

        comm_op_list = []

        for dst_rank in range(self.world_size):
            if dst_rank == self_rank:
                continue
            comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))

        for src_rank in range(self.world_size):
            if src_rank == self_rank:
                continue
            comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
        if comm_op_list:
            reqs = dist.batch_isend_irecv(comm_op_list)

        for req in reqs:
            req.wait()

    def shutdown(self):
        """
        Clean up the EPLB process.
        """
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
            logger.info("[ModelRunner] EPLB process terminated")
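The update cadence driven by cur_iterations is easiest to see with concrete numbers; the values in the comment sketch below are assumptions picked only for illustration.

```python
# Assume num_iterations_eplb_update=10, num_wait_worker_iterations=5, num_moe_layers=3.
# Counting cur_iterations after each forward_end() call:
#   wakeup_eplb_worker_flag()   -> True at cur_iterations == 9
#   get_update_info_flag()      -> True at cur_iterations == 14
#   update_expert_weight_flag() -> True at cur_iterations in {15, 16, 17}
#   update_iteration() resets the counter once it reaches 10 + 5 + 3 == 18
#   (unless gate_eplb keeps it frozen).
```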
77
vllm_npu/eplb/utils.py
Normal file
@@ -0,0 +1,77 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/pull/23553 is merged in vllm, remove this model register.
import types

import torch


def get_expert_map(self, layer_id):
    return self.model.layers[layer_id].mlp.experts.get_map()


def get_log2phy_map(self, layer_id):
    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()


def get_all_expert_map(self, num_moe_layers):
    all_loads = []
    num_dense_layers = self.num_dense_layers if hasattr(
        self, "num_dense_layers") else 0
    for layer_id in range(num_moe_layers):
        load_tensor = self.get_expert_map(
            layer_id + num_dense_layers)  # (num_experts_per_layer,)
        all_loads.append(load_tensor)

    return torch.stack(all_loads, dim=0)


def get_all_moe_loads(self):
    num_dense_layers = self.num_dense_layers if hasattr(
        self, "num_dense_layers") else 0
    all_moe_loads = torch.stack(
        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load
         for layer_id in range(self.num_moe_layers)],
        dim=0)
    return all_moe_loads


def clear_all_moe_loads(self):
    num_dense_layers = self.num_dense_layers if hasattr(
        self, "num_dense_layers") else 0
    for layer_id in range(self.num_moe_layers):
        self.model.layers[layer_id +
                          num_dense_layers].mlp.experts.clear_moe_load()


def model_register(model, model_config):
    model.get_expert_map = types.MethodType(get_expert_map, model)
    model.get_log2phy_map = types.MethodType(get_log2phy_map, model)
    model.get_all_expert_map = types.MethodType(get_all_expert_map, model)
    model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
    model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)

    config = model_config.hf_config

    if config.model_type == "qwen3_moe":
        model.num_moe_layers = config.num_hidden_layers
    elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
        model.num_dense_layers = config.first_k_dense_replace
        model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
    else:
        raise NotImplementedError("EPLB is not supported.")
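A minimal sketch of how model_register attaches these helpers to a model object; the SimpleNamespace stand-ins below are assumptions for illustration, not the real vLLM model or config classes.

```python
from types import SimpleNamespace

# Hypothetical stand-ins: a config that looks like a DeepSeek-V3 hf_config and
# a bare object in place of the loaded model.
hf_config = SimpleNamespace(model_type="deepseek_v3",
                            num_hidden_layers=61,
                            first_k_dense_replace=3)
model = SimpleNamespace()
model_register(model, SimpleNamespace(hf_config=hf_config))

assert model.num_dense_layers == 3 and model.num_moe_layers == 58
```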
0
vllm_npu/lora/__init__.py
Normal file
113
vllm_npu/lora/lora_ops.py
Normal file
@@ -0,0 +1,113 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def bgmv_shrink(inputs: torch.Tensor,
                lora_a_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
    return torch.ops._C_ascend.bgmv_shrink(
        inputs,
        lora_a_weights,
        lora_indices_tensor,
        output_tensor,
        scaling,
    )


def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
    return torch.ops._C_ascend.bgmv_expand(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
        output_tensor,
        0,
        output_tensor.size(1),
    )


def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
                                           lora_indices_tensor, output_tensor,
                                           slice_offset, slice_size)


def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
                                           lora_indices_tensor,
                                           seq_len_tensor, output_tensor,
                                           scaling)


def sgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                b_seq_start_loc: torch.Tensor,
                seq_len_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                batches: int,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
    return torch.ops._C_ascend.sgmv_expand(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
        seq_len_tensor,
        output_tensor,
        0,
        output_tensor.size(1),
    )


def sgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      b_seq_start_loc: torch.Tensor,
                      seq_len_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      batches: int,
                      max_seq_length: int,
                      token_nums: int,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
    return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
                                           lora_indices_tensor,
                                           seq_len_tensor, output_tensor,
                                           slice_offset, slice_size)
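These wrappers forward to the compiled `_C_ascend` custom ops, so the main thing to get right on the caller side is tensor layout. The shape sketch below is an assumption for illustration and is not executable without the compiled extension.

```python
import torch

# Assumed layout sketch: num_tokens tokens, rank-r LoRA, hidden size h,
# max_loras stacked adapters (dimensions chosen arbitrarily).
num_tokens, h, r, max_loras = 8, 4096, 16, 4
inputs = torch.randn(num_tokens, h)                    # activations
lora_a = torch.randn(max_loras, r, h)                  # stacked lora_a weights
output = torch.zeros(num_tokens, r)                    # shrink result buffer
indices = torch.zeros(num_tokens, dtype=torch.int64)   # adapter id per token

# bgmv_shrink(inputs, lora_a, output, indices, scaling=1.0) would fill `output`
# with the per-token (inputs @ lora_a[indices].transpose(-1, -2)) * scaling.
```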
356
vllm_npu/lora/punica_npu.py
Normal file
@@ -0,0 +1,356 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Optional, Tuple, Union

import torch

from vllm_npu.utils import is_310p

if is_310p():
    from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
                                         bgmv_shrink, sgmv_expand,
                                         sgmv_expand_slice, sgmv_shrink)
else:
    from vllm_npu.lora.lora_ops import (bgmv_expand, bgmv_expand_slice,
                                        bgmv_shrink, sgmv_expand,
                                        sgmv_expand_slice, sgmv_shrink)

from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase

from vllm_npu.lora.utils import refresh_all_lora_classes


# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperNPU(PunicaWrapperBase):
    """
    PunicaWrapperNPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)
        refresh_all_lora_classes()

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_inputs,
        )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_inputs)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool = True,
    ):
        """
        Perform the `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`
        computation, which is suitable for the GEMM of lora_b.
        """

        expand_slice_fun: Callable = (self._expand_slice_prefill
                                      if self.is_prefill else
                                      self._expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, scale: float):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = (self._shrink_prefill
                                if self.is_prefill else self._shrink_decode)
        shrink_fun(y, x, w_t_all, scale)
        y = y.view_as(y_org)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """

        x = x.view(-1, x.shape[-1])
        # TODO fuse these kernels
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (Tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_left,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """

        # The embedding layer only needs the expand op
        expand_fun: Callable = (self._expand_prefill
                                if self.is_prefill else self._expand_decode)
        expand_fun(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0) + lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """

        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(self.token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = tuple(
                torch.zeros(
                    (x.size(0), r), dtype=torch.float32, device=x.device)
                for _ in range(len(output_slices)))
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_inputs=True,
                        **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Default to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)

        if buffer is None:
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        indices = self.sampler_indices

        bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
        bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)

        y = y.view_as(y_org)
110
vllm_npu/lora/utils.py
Normal file
@@ -0,0 +1,110 @@
from typing import Optional

import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
from vllm.lora.layers.utils import _not_fully_sharded_can_replace

from vllm_npu.ops.linear import (AscendColumnParallelLinear,
                                 AscendMergedColumnParallelLinear,
                                 AscendQKVParallelLinear,
                                 AscendRowParallelLinear)
from vllm_npu.ops.vocab_parallel_embedding import \
    AscendVocabParallelEmbedding


class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is AscendColumnParallelLinear


class AscendMergedColumnParallelLinearWithLoRA(
        MergedColumnParallelLinearWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is AscendMergedColumnParallelLinear


class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is AscendRowParallelLinear


class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is AscendVocabParallelEmbedding


class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: list,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is AscendQKVParallelLinear and len(
            packed_modules_list) == 1


class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is AscendQKVParallelLinear
                and len(packed_modules_list) == 3)


def refresh_all_lora_classes():
    vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(
        AscendMergedColumnParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(
        AscendMergedQKVParallelLinearWithLoRA)
105
vllm_npu/meta_registration.py
Normal file
@@ -0,0 +1,105 @@
import torch
from torch.library import Library

# This file provides a template and registration utilities for writing "meta" implementations
# of custom operators in Python for the vllm_npu project.
#
# We offer two ways to implement meta implementations for custom ops:
# 1. Python meta implementation (as shown in this file): Write a Python function that
#    takes the same arguments as your operator and returns empty tensors with the correct
#    shapes and dtypes. This is useful for rapid prototyping and for ops that are only
#    used in Python.
# 2. C++ meta implementation: You can also implement the meta function in C++ for better
#    performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp`
#    for examples of C++ meta implementations and how to register them.
#
# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
# is essential for supporting `torch.compile` and aclgraph.

# How to add a new meta implementation in Python:
# -------------------------------------
# 1. Write a Python function that takes the same arguments as your operator, and returns
#    empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes.
#    Do NOT perform any real computation or allocate device memory.
#
# 2. Register your meta function using `register_meta_if_necessary`, providing:
#    - The namespace (usually "_C_ascend" for custom ops)
#    - The operator name (as registered in C++)
#    - The Python meta function
#    - (Optional) The overload name, if your op has overloads
#
# 3. The registration utility will check if a meta implementation already exists for your op,
#    and only register if necessary. This avoids duplicate registrations.
#
# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
#
# 5. When developing new custom ops, always provide a meta implementation to enable tracing,
#    export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
#    and aclgraph.
#
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors

lib = Library("_C_ascend", "IMPL")


def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
    if overload != "":
        op_name = op_name + "." + overload
    schema_to_find = ns + "::" + op_name
    meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
        "Meta")
    if schema_to_find in meta_impl_list:
        return
    lib.impl(op_name, fn, "Meta")


def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
                          key: torch.Tensor, head_size: int,
                          cos_sin_cache: torch.Tensor, is_neox: bool):
    num_tokens = positions.numel()
    query_hidden_size = query.numel() // num_tokens
    key_hidden_size = key.numel() // num_tokens
    num_heads = query_hidden_size // head_size
    num_kv_heads = key_hidden_size // head_size

    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
    return query_dst, key_dst


def get_masked_input_and_mask_meta(input: torch.Tensor,
                                   org_vocab_start_index: int,
                                   org_vocab_end_index: int,
                                   num_org_vocab_padding: int,
                                   added_vocab_start_index: int,
                                   added_vocab_end_index: int):
    masked_input = torch.empty_like(input)
    mask = torch.empty_like(input).to(torch.bool)

    return masked_input, mask


def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
                     indices: torch.Tensor, y: torch.Tensor,
                     slice_offset: int, slice_size: int):
    y_out = torch.empty_like(y)
    return y_out


def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
                     lora_indices: torch.Tensor, seq_len: torch.Tensor,
                     y: torch.Tensor, slice_offset: int, slice_size: int):
    y_out = torch.empty_like(y)
    return y_out


register_meta_if_necessary("_C_ascend", "rotary_embedding",
                           rotary_embedding_meta)
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
                           get_masked_input_and_mask_meta)
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)
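Following the recipe in the comment block above, a hypothetical registration would look like the sketch below; "my_fused_op" is a made-up operator name used purely for illustration and does not exist in the extension.

```python
# Hypothetical example: meta implementation for an op named
# "_C_ascend::my_fused_op" that returns a tensor shaped like its input.
def my_fused_op_meta(x: torch.Tensor, scale: float):
    # No real computation; only shape/dtype propagation on the Meta device.
    return torch.empty_like(x)

register_meta_if_necessary("_C_ascend", "my_fused_op", my_fused_op_meta)
```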
48
vllm_npu/models/__init__.py
Normal file
@@ -0,0 +1,48 @@
from vllm import ModelRegistry

import vllm_npu.envs as envs_ascend


def register_model():
    ModelRegistry.register_model(
        "Qwen2VLForConditionalGeneration",
        "vllm_npu.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")

    ModelRegistry.register_model(
        "Qwen3VLMoeForConditionalGeneration",
        "vllm_npu.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration"
    )

    ModelRegistry.register_model(
        "Qwen3VLForConditionalGeneration",
        "vllm_npu.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration"
    )

    if envs_ascend.USE_OPTIMIZED_MODEL:
        ModelRegistry.register_model(
            "Qwen2_5_VLForConditionalGeneration",
            "vllm_npu.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
        )
        ModelRegistry.register_model(
            "Qwen2_5OmniModel",
            "vllm_npu.models.qwen2_5_omni_thinker:AscendQwen2_5OmniThinkerForConditionalGeneration"
        )
    else:
        ModelRegistry.register_model(
            "Qwen2_5_VLForConditionalGeneration",
            "vllm_npu.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
        )

    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM",
        "vllm_npu.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")

    # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
    # to make sure the model can be loaded correctly. This register step can be removed once vLLM supports PanguProMoEForCausalLM.
    ModelRegistry.register_model(
        "PanguProMoEForCausalLM",
        "vllm_npu.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM")
    ModelRegistry.register_model(
        "Qwen3NextForCausalLM",
        "vllm_npu.models.qwen3_next:CustomQwen3NextForCausalLM")
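A quick way to sanity-check these registrations is to call the function directly and inspect the registry; this sketch assumes a vLLM installation whose ModelRegistry exposes get_supported_archs, which may differ between vLLM versions.

```python
from vllm import ModelRegistry

# Illustrative check only: after register_model() runs, the Ascend
# architectures should be listed among the supported architectures.
register_model()
assert "DeepseekV32ForCausalLM" in ModelRegistry.get_supported_archs()
```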
633
vllm_npu/models/deepseek_v3_2.py
Normal file
@@ -0,0 +1,633 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# # Adapted from
|
||||
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
|
||||
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
|
||||
from typing import Any, Dict, Iterable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (divide, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
yarn_get_mscale # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import (
|
||||
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
|
||||
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
|
||||
get_spec_layer_idx_from_weight_name)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.models.layers.sfa import (AscendSFAModules,
|
||||
AscendSparseFlashAttention, Indexer)
|
||||
from vllm_npu.ops.common_fused_moe import AscendFusedMoE
|
||||
from vllm_npu.ops.linear import AscendLinearBase
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
# Rewrite this init func mainly for removing cuda-hard code
|
||||
nn.Module.__init__(self)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
assert hasattr(config, "index_topk")
|
||||
topk_tokens = config.index_topk
|
||||
topk_indices_buffer = torch.empty(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
topk_tokens,
|
||||
dtype=torch.int32,
|
||||
device=current_platform.device_type)
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
|
||||
topk_indices_buffer),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
|
||||
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.tp_rank = (get_tensor_model_parallel_rank()
|
||||
if not disable_tp else 0)
|
||||
self.tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.update_param_tp_status()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
is_prefill=True,
|
||||
is_force_scatter=False
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
assert num_heads % self.tp_size == 0
|
||||
self.num_local_heads = num_heads // self.tp_size
|
||||
self.layers = config.num_hidden_layers
|
||||
self.first_k_dense_replace = config.first_k_dense_replace
|
||||
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.prefix = prefix
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_a_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa",
|
||||
return_bias=False,
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_b_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
self.o_proj = CustomDeepseekV2RowParallelLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=False)
|
||||
if rope_scaling:
|
||||
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
self.dim: int = config.hidden_size # 7168
|
||||
# TODO(zzzzwwjj): wait for transformers to add these params
self.n_heads: int = 64 # 64
|
||||
self.head_dim: int = 128 # 128
|
||||
self.index_topk: int = 2048 # 2048
|
||||
self.indexer = Indexer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
dim=self.dim,
|
||||
n_heads=self.n_heads,
|
||||
head_dim=self.head_dim,
|
||||
index_topk=self.index_topk,
|
||||
prefix=f"{prefix}.indexer",
|
||||
)
|
||||
|
||||
sfa_modules = AscendSFAModules(
|
||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
rotary_emb=self.rotary_emb,
|
||||
indexer=self.indexer)
|
||||
|
||||
self.sfa_attn = AscendSparseFlashAttention(
|
||||
self.hidden_size,
|
||||
self.enable_shared_expert_dp,
|
||||
self.debug_layer_idx,
|
||||
self.first_k_dense_replace,
|
||||
self.tp_size,
|
||||
sfa_modules,
|
||||
self.num_local_heads,
|
||||
self.scaling,
|
||||
self.layers,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
self.q_lora_rank,
|
||||
self.qk_nope_head_dim,
|
||||
self.qk_head_dim,
|
||||
self.v_head_dim,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.prefix = prefix
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)
|
||||
|
||||
|
||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
topk_indices_buffer=None) -> None:
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.layer_idx = layer_idx
|
||||
self.layers = config.num_hidden_layers
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
# TODO: enable mla in vllm-ascend
|
||||
if model_config.use_mla:
|
||||
attn_cls = CustomDeepseekV2SFAAttention
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=config.q_lora_rank
|
||||
if hasattr(config, "q_lora_rank") else None,
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0):
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
parallel_config=parallel_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
if self.mlp.gate.e_score_correction_bias is not None:
|
||||
self.mlp.gate.e_score_correction_bias.data = (
|
||||
self.mlp.gate.e_score_correction_bias.data.to(
|
||||
dtype=torch.get_default_dtype()))
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.first_k_dense_replace = config.first_k_dense_replace
|
||||
self.tp_group = get_tp_group().device_group
|
||||
|
||||
|
||||
class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
# `packed_modules_mapping` needs to be modified before
|
||||
# initializing DeepseekV2Model, as it is passed inplace to
|
||||
# quantization config init and may be used to select the
|
||||
# quant_method for relevant layers during initialization.
|
||||
self.fuse_qkv_a_proj = hasattr(
|
||||
config, "q_lora_rank") and config.q_lora_rank is not None
|
||||
if self.fuse_qkv_a_proj:
|
||||
self.packed_modules_mapping["fused_qkv_a_proj"] = [
|
||||
"q_a_proj",
|
||||
"kv_a_proj_with_mqa",
|
||||
]
|
||||
|
||||
self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.expert_weights: list[Any] = []
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (config.num_hidden_layers -
|
||||
config.first_k_dense_replace)
|
||||
self.num_expert_groups = config.n_group
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_moe = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, DeepseekV2DecoderLayer)
|
||||
if isinstance(layer.mlp, DeepseekV2MoE):
|
||||
# Pick the last layer, since the first ones may be dense layers.
example_moe = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_moe is None:
|
||||
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
|
||||
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
# NOTE: This `load_weights` is mainly copied from
|
||||
# https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
|
||||
# to fix CI, and it is different from the implementation in main
|
||||
# TODO: support eplb style load_weights
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
""""""
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if "module" in name:
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=False)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__
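# NOTE: this assignment monkey-patches the upstream decoder layer constructor, so
# layers built by vLLM's make_layers pick up the NPU-adapted attention/MLP wiring
# above even when the stock DeepseekV2DecoderLayer class is instantiated.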
0
vllm_npu/models/layers/__init__.py
Normal file
193
vllm_npu/models/layers/mla.py
Normal file
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.mla import MLAModules
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.attention import Attention
|
||||
from vllm.model_executor.layers.mla import \
|
||||
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
|
||||
else:
|
||||
from vllm.attention.layer import MLAAttention
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
|
||||
|
||||
|
||||
# TODO(whx): adapt v0.11.0 and DSA
|
||||
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
mla_modules: MLAModules,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.prefix = prefix
|
||||
hf_config = get_current_vllm_config().model_config.hf_config
|
||||
self.enable_shared_expert_dp = get_ascend_config(
|
||||
).enable_shared_expert_dp
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
self.first_k_dense_replace = hf_config.first_k_dense_replace
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.layers = hf_config.num_hidden_layers
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
self.mla_attn = Attention(
|
||||
num_heads=num_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
rotary_emb=mla_modules.rotary_emb,
|
||||
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
|
||||
q_b_proj=mla_modules.q_b_proj,
|
||||
q_a_layernorm=mla_modules.q_a_layernorm,
|
||||
q_proj=mla_modules.q_proj,
|
||||
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=mla_modules.kv_a_layernorm,
|
||||
kv_b_proj=mla_modules.kv_b_proj,
|
||||
o_proj=mla_modules.o_proj,
|
||||
)
|
||||
else:
|
||||
self.mla_attn = MLAAttention(
|
||||
num_heads=num_heads,
scale=scale,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
kv_b_proj=mla_modules.kv_b_proj,
|
||||
use_sparse=mla_modules.is_sparse,
|
||||
indexer=mla_modules.indexer,
|
||||
# extra args
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
rotary_emb=mla_modules.rotary_emb,
|
||||
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
|
||||
q_b_proj=mla_modules.q_b_proj,
|
||||
q_a_layernorm=mla_modules.q_a_layernorm,
|
||||
q_proj=mla_modules.q_proj,
|
||||
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=mla_modules.kv_a_layernorm,
|
||||
o_proj=mla_modules.o_proj,
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
need_gather_q_kv = get_forward_context().sp_enabled
|
||||
output_shape = hidden_states.shape
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
|
||||
self.prefix)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
|
||||
def mla_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
if forward_context.attn_metadata:
|
||||
attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name]
|
||||
else:
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||
self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states,
|
||||
kv_cache, attn_metadata, need_gather_q_kv,
|
||||
output)
|
||||
return
|
||||
|
||||
|
||||
def mla_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mla_forward",
|
||||
op_func=mla_forward,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mla_forward_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
)
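# Usage sketch (illustrative, not part of the upstream API surface): once the op
# is registered it is reachable through the torch.ops namespace, `output` is
# filled in place because it appears in `mutates_args`, and `mla_forward_fake`
# is the no-op body that graph tracing sees. The layer name below is a
# hypothetical example; it must match a key in forward_context.no_compile_layers.
#
#   output = torch.empty_like(hidden_states)
#   torch.ops.vllm.mla_forward(hidden_states, False, output,
#                              "model.layers.3.self_attn.mla_attn")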
233
vllm_npu/models/layers/sfa.py
Normal file
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAModules:
|
||||
q_a_proj: Optional[torch.nn.Module]
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
kv_a_proj_with_mqa: torch.nn.Module
|
||||
kv_a_layernorm: torch.nn.Module
|
||||
kv_b_proj: torch.nn.Module
|
||||
o_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
indexer: torch.nn.Module
|
||||
|
||||
|
||||
class AscendSparseFlashAttention(MultiHeadLatentAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
enable_shared_expert_dp: bool,
|
||||
debug_layer_idx: int,
|
||||
first_k_dense_replace: int,
|
||||
tp_size: int,
|
||||
sfa_modules: AscendSFAModules,
|
||||
num_local_heads: int,
|
||||
scaling: float,
|
||||
layers: int,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
qk_nope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
self.debug_layer_idx = debug_layer_idx
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.tp_size = tp_size
|
||||
self.num_local_heads = num_local_heads
|
||||
self.layers = layers
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.prefix = prefix
|
||||
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=True,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=sfa_modules.rotary_emb,
|
||||
q_a_proj=sfa_modules.q_a_proj,
|
||||
q_a_layernorm=sfa_modules.q_a_layernorm,
|
||||
q_proj=sfa_modules.q_proj,
|
||||
kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=sfa_modules.kv_a_layernorm,
|
||||
kv_b_proj=sfa_modules.kv_b_proj,
|
||||
o_proj=sfa_modules.o_proj,
|
||||
indexer=sfa_modules.indexer)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
need_gather_q_kv = False
|
||||
if (self.enable_shared_expert_dp and self.first_k_dense_replace < self.debug_layer_idx < self.layers):
# Simulate all gather to calculate output shape
|
||||
num_tokens = num_tokens * self.tp_size
|
||||
need_gather_q_kv = True
|
||||
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
rows = num_tokens // self.tp_size
|
||||
if num_tokens % self.tp_size:
|
||||
rows += 1
|
||||
output_shape = (rows, hidden_states.shape[1])
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
|
||||
self.prefix)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
|
||||
def sfa_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
if forward_context.attn_metadata:
|
||||
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
|
||||
else:
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
|
||||
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv, output)
|
||||
return
|
||||
|
||||
|
||||
class Indexer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
dim: int = 7168,
|
||||
n_heads: int = 64,
|
||||
head_dim: int = 128,
|
||||
index_topk: int = 2048,
|
||||
q_lora_rank: int = 1536,
|
||||
rope_head_dim: int = 64,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: Optional[str] = ""):
|
||||
super().__init__()
|
||||
|
||||
self.dim: int = dim # 7168
|
||||
self.n_heads: int = n_heads # 64
|
||||
self.head_dim: int = head_dim # 128
|
||||
self.rope_head_dim: int = rope_head_dim # 64
|
||||
self.index_topk: int = index_topk # 2048
|
||||
self.q_lora_rank: int = q_lora_rank # 1536
|
||||
self.wq_b = ReplicatedLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
return_bias=False,
|
||||
)
|
||||
self.wk = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wk",
|
||||
return_bias=False,
|
||||
)
|
||||
self.weights_proj = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.n_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
self.k_norm = nn.LayerNorm(self.head_dim)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
def forward(self):
|
||||
return
|
||||
|
||||
|
||||
def sfa_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sfa_forward",
|
||||
op_func=sfa_forward,
|
||||
mutates_args=["output"],
|
||||
fake_impl=sfa_forward_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
)
54
vllm_npu/models/qwen2_5_omni_thinker.py
Normal file
@@ -0,0 +1,54 @@
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import \
|
||||
Qwen2_5OmniThinkerConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.qwen2_5_omni_thinker import (
|
||||
Qwen2_5OmniThinkerDummyInputsBuilder,
|
||||
Qwen2_5OmniThinkerForConditionalGeneration,
|
||||
Qwen2_5OmniThinkerMultiModalProcessor, Qwen2_5OmniThinkerProcessingInfo)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_npu.models.qwen2_5_vl import AscendQwen2_5_VisionTransformer
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5OmniThinkerMultiModalProcessor,
|
||||
info=Qwen2_5OmniThinkerProcessingInfo,
|
||||
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder)
|
||||
class AscendQwen2_5OmniThinkerForConditionalGeneration(
|
||||
Qwen2_5OmniThinkerForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen2_5OmniThinkerConfig = vllm_config.model_config.hf_config.thinker_config
|
||||
quant_config = vllm_config.quant_config
|
||||
# The following code reuses AscendQwen2_5_VisionTransformer from Qwen2_5_VL.
# It does not introduce any model-structure differences and will not affect
# removal of the modeling files.
|
||||
self.visual = AscendQwen2_5_VisionTransformer(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
628
vllm_npu/models/qwen2_5_vl.py
Normal file
@@ -0,0 +1,628 @@
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from einops import rearrange
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
|
||||
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
|
||||
Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_npu.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
projection_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads)
|
||||
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
|
||||
if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
|
||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head)
|
||||
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
q, k, v = self.split_qkv(x)
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
|
||||
for x in (q, k, v))
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
|
||||
q, k, v = [
|
||||
rearrange(x, "b s h d -> (b s) h d").contiguous()
|
||||
for x in (q, k, v)
|
||||
]
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=cu_seqlens,
|
||||
scale_value=self.origin_hidden_size_per_attention_head**-0.5,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
num_kv_heads=self.num_attention_heads_per_partition,
|
||||
out=context_layer)
|
||||
|
||||
context_layer = rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
b=batch_size).contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
|
||||
quant_config, prefix)
|
||||
self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
|
||||
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__(dim, theta)
|
||||
inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.inv_freq = inv_freq
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen2_5_VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
interleaved=False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||
self.interleaved = interleaved
|
||||
self.enable_pad = False
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
|
||||
patch_size=vision_config.patch_size,
|
||||
temporal_patch_size=vision_config.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
act_fn = get_act_and_mul_fn(vision_config.hidden_act)
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen2_5_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=act_fn,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
self.enable_pad = True
|
||||
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
|
||||
self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
|
||||
self.half_pad_hidden_size_per_attention_head = (
|
||||
MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
|
||||
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
if self.enable_pad:
|
||||
cos = torch.nn.functional.pad(
|
||||
cos, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
sin = torch.nn.functional.pad(
|
||||
sin, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
|
||||
if not self.interleaved:
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
else:
|
||||
cos_new = rearrange(torch.stack((cos, cos), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
sin_new = rearrange(torch.stack((sin, sin), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
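# Worked example of the padding arithmetic above (assuming a vision head_dim of
# 80, which lies strictly between MIN_PAD_SIZE=64 and MAX_PAD_SIZE=128):
#   half_origin_hidden_size_per_attention_head = 80 // 2 = 40
#   half_pad_hidden_size_per_attention_head    = (128 - 80) // 2 = 24
#   cos/sin arrive with a last dim of head_dim / 2 = 40, are padded to 64, and
#   the cat((cos, cos)) / cat((sin, sin)) step doubles that to 128, matching the
#   padded hidden_size_per_attention_head used by the attention kernel.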
|
||||
|
||||
def pad_qkv_bias(self, bias):
|
||||
first_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head]
|
||||
second_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:]
|
||||
first_half_padded = torch.nn.functional.pad(
|
||||
first_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
second_half_padded = torch.nn.functional.pad(
|
||||
second_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
|
||||
bias_final = bias_padded.reshape(-1)
|
||||
return bias_final
|
||||
|
||||
def pad_qkv_weight(self, data):
|
||||
qkv_weight_first_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head, :]
|
||||
qkv_weight_second_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:, :]
|
||||
|
||||
qkv_weight_first_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_first_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_second_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_second_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_padded = torch.cat(
|
||||
[qkv_weight_first_half_padded, qkv_weight_second_half_padded],
|
||||
dim=2)
|
||||
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
|
||||
|
||||
if is_enable_nz(qkv_weight_final.dtype):
|
||||
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
|
||||
qkv_weight_final)
|
||||
qkv_weight_final_copy = torch_npu.npu_format_cast(
|
||||
qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND)
|
||||
return qkv_weight_final_copy
|
||||
|
||||
return qkv_weight_final
|
||||
|
||||
def pad_proj_weight(self, data):
|
||||
out_weight = torch.nn.functional.pad(
|
||||
data.reshape(self.hidden_size, -1,
|
||||
self.half_origin_hidden_size_per_attention_head),
|
||||
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
|
||||
self.hidden_size, -1)
|
||||
|
||||
if is_enable_nz(out_weight.dtype):
|
||||
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
|
||||
out_weight_copy = torch_npu.npu_format_cast(
|
||||
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
|
||||
return out_weight_copy
|
||||
|
||||
return out_weight
|
||||
|
||||
def pad_qkv_weight_scale_offset(self, data):
|
||||
reshaped_data = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, 1)
|
||||
data1 = reshaped_data[:, :, :self.half_origin_hidden_size_per_attention_head, :]
data2 = reshaped_data[:, :, self.half_origin_hidden_size_per_attention_head:, :]
data1_paded = torch.nn.functional.pad(
|
||||
data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
|
||||
0, 0, 0))
|
||||
data2_paded = torch.nn.functional.pad(
|
||||
data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
|
||||
0, 0, 0))
|
||||
res = torch.cat([data1_paded, data2_paded], dim=2)
|
||||
res = res.reshape(-1, 1)
|
||||
return res
|
||||
|
||||
def pad_qkv_deq_scale_quant_bias(self, data):
|
||||
reshaped_data = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head)
|
||||
data1 = reshaped_data[:, :, :self.half_origin_hidden_size_per_attention_head]
data2 = reshaped_data[:, :, self.half_origin_hidden_size_per_attention_head:]
|
||||
data1_paded = torch.nn.functional.pad(
|
||||
data1, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
data2_paded = torch.nn.functional.pad(
|
||||
data2, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
|
||||
res = torch.cat([data1_paded, data2_paded], dim=2)
|
||||
res = res.reshape(-1)
|
||||
return res
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("attn.qkv.", "attn.q.", "q"),
|
||||
("attn.qkv.", "attn.k.", "k"),
|
||||
("attn.qkv.", "attn.v.", "v"),
|
||||
("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
|
||||
("mlp.gate_up_proj.", "mlp.up_proj.", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
if self.enable_pad and shard_id == "v":
|
||||
if "attn.qkv.weight" in name:
|
||||
param.data = self.pad_qkv_weight(param.data)
|
||||
if "attn.qkv.bias" in name:
|
||||
param.data = self.pad_qkv_bias(param.data)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if ("attn.proj.weight_scale" in name or
|
||||
"attn.proj.weight_offset" in name) and self.enable_pad:
|
||||
continue
|
||||
elif ("attn.proj.deq_scale" in name
|
||||
or "attn.proj.quant_bias" in name) and self.enable_pad:
|
||||
continue
|
||||
elif ("attn.qkv.weight_scale" in name
|
||||
or "attn.qkv.weight_offset" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_weight_scale_offset(param.data)
|
||||
elif ("attn.qkv.deq_scale" in name
|
||||
or "attn.qkv.quant_bias" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
|
||||
elif ("attn.proj.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_proj_weight(param.data)
|
||||
elif ("attn.qkv.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_weight(param.data)
|
||||
elif ("attn.qkv.bias" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_bias(param.data)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
pos_ids.append(
|
||||
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (self.window_size //
|
||||
self.spatial_merge_size // self.patch_size)
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h = grid_h // self.spatial_merge_size
|
||||
llm_grid_w = grid_w // self.spatial_merge_size
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||
grid_t, llm_grid_h, llm_grid_w)
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
|
||||
index_padded = index_padded.reshape(grid_t, num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
|
||||
vit_merger_window_size)
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cpu().to(torch.int32)
|
||||
|
||||
# patchify
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
# windows attention
|
||||
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
||||
cu_window_seqlens = torch.tensor(
|
||||
cu_window_seqlens,
|
||||
device=x.device,
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
|
||||
seq_len, _ = x.size()
|
||||
x = x.reshape(seq_len // self.spatial_merge_unit,
|
||||
self.spatial_merge_unit, -1)
|
||||
x = x[window_index, :, :]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
if layer_num in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
x = x[reverse_indices, :]
|
||||
return x
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5_VLMultiModalProcessor,
|
||||
info=Qwen2_5_VLProcessingInfo,
|
||||
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
|
||||
class AscendQwen2_5_VLForConditionalGeneration(
|
||||
Qwen2_5_VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen2_5_VisionTransformer(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return image_embeds.split(sizes.tolist())
|
||||
|
||||
def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if video_input["type"] == "video_embeds":
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||
self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return video_embeds.split(sizes.tolist())
|
||||
|
||||
def _get_text_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor],
|
||||
handle_oov_mm_token: bool,
|
||||
) -> torch.Tensor:
|
||||
if handle_oov_mm_token and is_multimodal is not None:
|
||||
is_text = ~is_multimodal
|
||||
text_embeds = get_input_embeddings(input_ids[is_text])
|
||||
|
||||
return torch.empty(
|
||||
(input_ids.shape[0], text_embeds.shape[1]),
|
||||
dtype=text_embeds.dtype,
|
||||
device=text_embeds.device,
|
||||
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
|
||||
|
||||
return get_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply token embeddings to `input_ids`.
|
||||
|
||||
If `multimodal_embeddings` is passed, scatter them into
|
||||
`input_ids` according to the mask `is_multimodal`.
|
||||
|
||||
In case the multi-modal token IDs exceed the vocabulary size of
|
||||
the language model, you can set `handle_oov_mm_token=True`
to avoid calling the language model's `get_input_embeddings` method
|
||||
on those tokens. Note however that doing so increases memory usage
|
||||
as an additional buffer is needed to hold the input embeddings.
|
||||
"""
|
||||
from vllm.model_executor.models.utils import \
|
||||
_merge_multimodal_embeddings
|
||||
|
||||
inputs_embeds = self._get_text_embeddings(
|
||||
input_ids,
|
||||
self.get_language_model().get_input_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
if is_multimodal is None:
|
||||
raise ValueError(
|
||||
"`get_input_embeddings` now requires `is_multimodal` arg, "
|
||||
"please update your model runner according to "
|
||||
"https://github.com/vllm-project/vllm/pull/16229.")
|
||||
|
||||
return _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
is_multimodal=is_multimodal,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
)
|
||||
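
# Usage sketch (hypothetical tensors, not a prescribed call sequence): a model
# runner that has built a boolean `is_multimodal` mask over the flattened token
# sequence would combine text and image features roughly as follows.
#
#   image_embeds = model._process_image_input(image_input)   # tuple of [N_i, H]
#   inputs_embeds = model.get_input_embeddings(
#       input_ids,
#       multimodal_embeddings=image_embeds,
#       is_multimodal=is_multimodal,
#   )
#
# The mask must mark exactly the multimodal placeholder positions; otherwise
# `_merge_multimodal_embeddings` cannot scatter the visual features correctly.
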
780
vllm_npu/models/qwen2_5_vl_without_padding.py
Normal file
@@ -0,0 +1,780 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from einops import rearrange
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
|
||||
try:
|
||||
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
|
||||
Qwen3VLConfig
|
||||
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \
|
||||
Qwen3VLMoeConfig
|
||||
except ImportError:
|
||||
pass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
||||
get_act_and_mul_fn)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
|
||||
Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
|
||||
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
|
||||
Qwen2_5_VLProcessingInfo)
|
||||
|
||||
try:
|
||||
from vllm.model_executor.models.qwen3_vl import (
|
||||
Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer,
|
||||
Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
|
||||
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
|
||||
from vllm.model_executor.models.qwen3_vl_moe import (
|
||||
Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo)
|
||||
except ImportError:
|
||||
Qwen3_VisionBlock = object
|
||||
Qwen3_VisionPatchEmbed = object
|
||||
Qwen3_VisionTransformer = object
|
||||
Qwen3VLDummyInputsBuilder = object
|
||||
Qwen3VLForConditionalGeneration = object
|
||||
Qwen3VLMultiModalProcessor = object
|
||||
Qwen3VLProcessingInfo = object
|
||||
Qwen3VLMoeForConditionalGeneration = object
|
||||
Qwen3VLMoeProcessingInfo = object
|
||||
from vllm.model_executor.models.utils import (WeightsMapper,
|
||||
_merge_multimodal_embeddings,
|
||||
maybe_prefix)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_npu.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.empty_like(q)

        # operator requires pta version >= 2.5.1.dev20250226
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output

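# Reference sketch (comments only; the helper below is hypothetical and only
# approximates what the fused kernel computes): `cu_seqlens` here appears to be
# a vector of per-frame sequence lengths rather than cumulative offsets, so the
# unpadded attention corresponds, segment by segment, to ordinary scaled
# dot-product attention:
#
#   def _reference_unpad_attention(q, k, v, seq_lens, scale):
#       # q, k, v: [total_tokens, heads, head_dim]; seq_lens: iterable of ints
#       outs, start = [], 0
#       for n in seq_lens:
#           qi, ki, vi = (t[start:start + n].transpose(0, 1) for t in (q, k, v))
#           attn = torch.softmax((qi @ ki.transpose(-1, -2)) * scale, dim=-1) @ vi
#           outs.append(attn.transpose(0, 1))
#           start += n
#       return torch.cat(outs, dim=0)
#
# with scale = hidden_size_per_attention_head ** -0.5, matching the call above.
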
class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_hidden_dim: int,
                 act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
                 norm_layer: Optional[Callable[[int], nn.Module]] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)

        x = x + self.mlp(self.norm2(x))
        return x

class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen2_5_VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
interleaved=False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||
self.interleaved = interleaved
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
|
||||
2)
|
||||
self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding(
|
||||
patch_size=vision_config.patch_size,
|
||||
temporal_patch_size=vision_config.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
act_fn = get_act_and_mul_fn(vision_config.hidden_act)
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen2_5_VisionBlock_Without_Padding(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=act_fn,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
|
||||
if not self.interleaved:
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
else:
|
||||
cos_new = rearrange(torch.stack((cos, cos), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
sin_new = rearrange(torch.stack((sin, sin), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
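    # For intuition (hypothetical numbers): with rotary_dim / 2 == 2 and
    # cos == [c0, c1], the two layouts built above differ only in ordering:
    #   non-interleaved: torch.cat((cos, cos))             -> [c0, c1, c0, c1]
    #   interleaved:     stack + "... d two -> ...(d two)" -> [c0, c0, c1, c1]
    # Both are then reshaped to [1, seqlen, 1, head_dim] so the rotary multiply
    # in the attention blocks can broadcast them over heads.
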
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
pos_ids.append(
|
||||
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (self.window_size //
|
||||
self.spatial_merge_size // self.patch_size)
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h = grid_h // self.spatial_merge_size
|
||||
llm_grid_w = grid_w // self.spatial_merge_size
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||
grid_t, llm_grid_h, llm_grid_w)
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
|
||||
index_padded = index_padded.reshape(grid_t, num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
|
||||
vit_merger_window_size)
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
cu_seqlens_tmp = seqlens.cumsum(
|
||||
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
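    # Worked example for the window padding above (hypothetical sizes): with
    # llm_grid_h == 13 and vit_merger_window_size == 4,
    #   pad_h         = 4 - 13 % 4    = 3
    #   num_windows_h = (13 + 3) // 4 = 4
    # The -100 filler entries are dropped again via `index_padded != -100`,
    # so `seqlens` counts only real patches per window.
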
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
grid_thw[:,
|
||||
0]).cpu().to(torch.int32)
|
||||
|
||||
# patchify
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
# windows attention
|
||||
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
||||
cu_window_seqlens = torch.tensor(
|
||||
cu_window_seqlens,
|
||||
device=x.device,
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
|
||||
seq_len, _ = x.size()
|
||||
x = x.reshape(seq_len // self.spatial_merge_unit,
|
||||
self.spatial_merge_unit, -1)
|
||||
x = x[window_index, :, :]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
if layer_num in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
x = x[reverse_indices, :]
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
|
||||
x = x + self.proj.bias
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionBlock(Qwen3_VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
|
||||
quant_config, prefix, use_data_parallel)
|
||||
self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix,
|
||||
use_data_parallel)
|
||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||
self.patch_embed = AscendQwen3_VisionPatchEmbed(
|
||||
patch_size=self.patch_size,
|
||||
temporal_patch_size=self.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen3_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
grid_thw_tensor = torch.tensor(grid_thw,
|
||||
device=self.device,
|
||||
dtype=torch.int32)
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
|
||||
grid_thw_tensor[:, 0]).cpu().to(torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
|
||||
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
hidden_states = blk(hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
if layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(
|
||||
layer_num)
|
||||
deepstack_feature = self.deepstack_merger_list[
|
||||
deepstack_merger_idx](hidden_states)
|
||||
deepstack_feature_lists.append(deepstack_feature)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
hidden_states = torch.cat(
|
||||
[hidden_states] + deepstack_feature_lists,
|
||||
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5_VLMultiModalProcessor,
|
||||
info=Qwen2_5_VLProcessingInfo,
|
||||
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
|
||||
class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
|
||||
Qwen2_5_VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
|
||||
vision_config=config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return image_embeds.split(sizes.tolist())
|
||||
|
||||
def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if video_input["type"] == "video_embeds":
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||
self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return video_embeds.split(sizes.tolist())
|
||||
|
||||
def _get_text_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor],
|
||||
handle_oov_mm_token: bool,
|
||||
) -> torch.Tensor:
|
||||
if handle_oov_mm_token and is_multimodal is not None:
|
||||
is_text = ~is_multimodal
|
||||
text_embeds = get_input_embeddings(input_ids[is_text])
|
||||
|
||||
return torch.empty(
|
||||
(input_ids.shape[0], text_embeds.shape[1]),
|
||||
dtype=text_embeds.dtype,
|
||||
device=text_embeds.device,
|
||||
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
|
||||
|
||||
return get_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply token embeddings to `input_ids`.
|
||||
|
||||
If `multimodal_embeddings` is passed, scatter them into
|
||||
`input_ids` according to the mask `is_multimodal`.
|
||||
|
||||
In case the multi-modal token IDs exceed the vocabulary size of
|
||||
the language model, you can set `handle_oov_mm_token=True`
|
||||
to avoid calling the language model's `get_input_embeddings` method
|
||||
on those tokens. Note however that doing so increases memory usage
|
||||
as an additional buffer is needed to hold the input embeddings.
|
||||
"""
|
||||
|
||||
inputs_embeds = self._get_text_embeddings(
|
||||
input_ids,
|
||||
self.get_language_model().get_input_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
if is_multimodal is None:
|
||||
raise ValueError(
|
||||
"`get_input_embeddings` now requires `is_multimodal` arg, "
|
||||
"please update your model runner according to "
|
||||
"https://github.com/vllm-project/vllm/pull/16229.")
|
||||
|
||||
return _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
is_multimodal=is_multimodal,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder)
|
||||
class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.visual.": "visual.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen3VLConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLMoeProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder)
|
||||
class AscendQwen3VLMoeForConditionalGeneration(
|
||||
Qwen3VLMoeForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.visual.": "visual.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.visual = AscendQwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
|
||||
def _get_text_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor],
|
||||
handle_oov_mm_token: bool,
|
||||
) -> torch.Tensor:
|
||||
if handle_oov_mm_token and is_multimodal is not None:
|
||||
is_text = ~is_multimodal
|
||||
text_embeds = get_input_embeddings(input_ids[is_text])
|
||||
return torch.empty(
|
||||
(input_ids.shape[0], text_embeds.shape[1]),
|
||||
dtype=text_embeds.dtype,
|
||||
device=text_embeds.device,
|
||||
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
|
||||
return get_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply token embeddings to `input_ids`.
|
||||
If `multimodal_embeddings` is passed, scatter them into
|
||||
`input_ids` according to the mask `is_multimodal`.
|
||||
In case the multi-modal token IDs exceed the vocabulary size of
|
||||
the language model, you can set `handle_oov_mm_token=True`
|
||||
to avoid calling the language model's `get_input_embeddings` method
|
||||
on those tokens. Note however that doing so increases memory usage
|
||||
as an additional buffer is needed to hold the input embeddings.
|
||||
"""
|
||||
inputs_embeds = self._get_text_embeddings(
|
||||
input_ids,
|
||||
self.get_language_model().get_input_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
if is_multimodal is None:
|
||||
raise ValueError(
|
||||
"`get_input_embeddings` now requires `is_multimodal` arg, "
|
||||
"please update your model runner according to "
|
||||
"https://github.com/vllm-project/vllm/pull/16229.")
|
||||
if self.use_deepstack:
|
||||
(
|
||||
deepstack_input_embeds,
|
||||
multimodal_embeddings,
|
||||
) = self._compute_deepstack_embeds(
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
)
|
||||
else:
|
||||
deepstack_input_embeds = None
|
||||
inputs_embeds = _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
is_multimodal=is_multimodal,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
)
|
||||
if deepstack_input_embeds is not None:
|
||||
self._set_deepstack_input_embeds(deepstack_input_embeds)
|
||||
return inputs_embeds
|
||||
|
||||
def _compute_deepstack_embeds(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings,
|
||||
is_multimodal: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
|
||||
|
||||
visual_lens = [len(x) for x in multimodal_embeddings]
|
||||
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
|
||||
|
||||
total_dim = multimodal_embeddings_cat.shape[-1]
|
||||
assert total_dim == self.visual_dim + self.multiscale_dim, \
|
||||
f"Total dimension mismatch: input {total_dim}, expected {self.visual_dim + self.multiscale_dim}"
|
||||
multimodal_embeddings_main = multimodal_embeddings_cat[
|
||||
..., :self.visual_dim]
|
||||
multimodal_embeddings_multiscale = multimodal_embeddings_cat[
|
||||
..., self.visual_dim:]
|
||||
|
||||
multimodal_embeddings = torch.split(multimodal_embeddings_main,
|
||||
visual_lens,
|
||||
dim=0)
|
||||
multimodal_embeddings_multiscale = torch.split(
|
||||
multimodal_embeddings_multiscale, visual_lens, dim=0)
|
||||
|
||||
deepstack_input_embeds = inputs_embeds.new_zeros(
|
||||
inputs_embeds.size(0),
|
||||
self.deepstack_num_level * inputs_embeds.size(1))
|
||||
|
||||
deepstack_input_embeds = _merge_multimodal_embeddings(
|
||||
inputs_embeds=deepstack_input_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings_multiscale,
|
||||
is_multimodal=is_multimodal,
|
||||
)
|
||||
deepstack_input_embeds = deepstack_input_embeds.view(
|
||||
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
|
||||
deepstack_input_embeds = deepstack_input_embeds.permute(
|
||||
1, 0, 2).contiguous()
|
||||
|
||||
return deepstack_input_embeds, multimodal_embeddings
|
||||
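# Shape sketch for _compute_deepstack_embeds above (a reading of the code, with
# hypothetical sizes): each visual embedding row concatenates the main feature
# with the multiscale ("deepstack") features, e.g. visual_dim = 2048 and
# deepstack_num_level = 3 gives a last dim of 2048 + 3 * 2048. The method keeps
# the first visual_dim columns for the normal merge and reshapes the remainder
# to [num_tokens, deepstack_num_level, visual_dim] for per-layer injection.
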
369
vllm_npu/models/qwen2_vl.py
Normal file
@@ -0,0 +1,369 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen2_vl.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from einops import rearrange
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import \
|
||||
Qwen2VLVisionConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen2_vl import (
|
||||
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed,
|
||||
Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder,
|
||||
Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor,
|
||||
Qwen2VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from vllm_npu.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz
|
||||
|
||||
MIN_PAD_SIZE = 64 # min_size to pad weight
|
||||
MAX_PAD_SIZE = 128 # max_size to pad weight
|
||||
|
||||
|
||||
class AscendQwen2VisionAttention(Qwen2VisionAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
projection_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.cu_seqlens = None
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads)
|
||||
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
|
||||
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
|
||||
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
self.cu_seqlens = cu_seqlens
|
||||
|
||||
# [s, b, c] --> [s, b, 3 * head * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
q, k, v = self.split_qkv(x)
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = [
|
||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
||||
]
|
||||
q = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
q, k, v = [
|
||||
rearrange(x, "b s h d -> (b s) h d").contiguous()
|
||||
for x in (q, k, v)
|
||||
]
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=self.cu_seqlens,
|
||||
scale_value=self.origin_hidden_size_per_attention_head**-0.5,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
num_kv_heads=self.num_attention_heads_per_partition,
|
||||
out=context_layer)
|
||||
context_layer = rearrange(context_layer,
|
||||
"(b s) h d -> s b (h d)",
|
||||
b=batch_size).contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class AscendQwen2VisionBlock(Qwen2VisionBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
act_layer: Type[nn.Module] = QuickGELU,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(dim, num_heads, mlp_ratio, act_layer, norm_layer,
|
||||
quant_config, prefix)
|
||||
self.attn = AscendQwen2VisionAttention(embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x),
|
||||
cu_seqlens=cu_seqlens,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
)
|
||||
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1))
|
||||
return x
|
||||
|
||||
|
||||
class AscendQwen2VisionTransformer(Qwen2VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen2VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
interleaved=False,
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
|
||||
self.interleaved = interleaved
|
||||
self.enable_pad = False
|
||||
self.depth = vision_config.depth
|
||||
self.hidden_size = vision_config.embed_dim
|
||||
self.num_heads = vision_config.num_heads
|
||||
self.patch_embed = AscendQwen2VisionPatchEmbed(
|
||||
patch_size=vision_config.patch_size,
|
||||
temporal_patch_size=vision_config.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
embed_dim=vision_config.embed_dim,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
AscendQwen2VisionBlock(dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=vision_config.mlp_ratio,
|
||||
norm_layer=partial(nn.LayerNorm,
|
||||
eps=norm_eps),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}")
|
||||
for layer_idx in range(vision_config.depth)
|
||||
])
|
||||
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
self.hidden_size, self.num_heads)
|
||||
|
||||
        if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.enable_pad = True
            self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
            self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
            self.half_pad_hidden_size_per_attention_head = (
                MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

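        # Worked example (hypothetical head size): with
        # hidden_size_per_attention_head == 80 the branch above yields
        #   origin    = 80, half_origin = 40
        #   half_pad  = (128 - 80) // 2 = 24
        #   padded    = MAX_PAD_SIZE    = 128
        # pad_qkv_weight / pad_qkv_bias below pad each rotary half separately
        # (40 -> 64), keeping them aligned with the padded cos/sin from
        # cal_cos_sin.
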
def cal_cos_sin(self, rotary_pos_emb):
|
||||
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
|
||||
sin = rotary_pos_emb.sin()
|
||||
if self.enable_pad:
|
||||
cos = torch.nn.functional.pad(
|
||||
cos, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
sin = torch.nn.functional.pad(
|
||||
sin, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
|
||||
if not self.interleaved:
|
||||
cos_new = torch.cat((cos, cos), dim=-1)
|
||||
sin_new = torch.cat((sin, sin), dim=-1)
|
||||
else:
|
||||
cos_new = rearrange(torch.stack((cos, cos), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
sin_new = rearrange(torch.stack((sin, sin), dim=-1),
|
||||
"... d two -> ...(d two)",
|
||||
two=2)
|
||||
cos_new = cos_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
sin_new = sin_new.reshape(1, -1, 1,
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def pad_qkv_bias(self, bias):
|
||||
first_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head]
|
||||
second_half = bias.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:]
|
||||
first_half_padded = torch.nn.functional.pad(
|
||||
first_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
second_half_padded = torch.nn.functional.pad(
|
||||
second_half, (0, self.half_pad_hidden_size_per_attention_head))
|
||||
bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
|
||||
bias_final = bias_padded.reshape(-1)
|
||||
return bias_final
|
||||
|
||||
def pad_qkv_weight(self, data):
|
||||
qkv_weight_first_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, :self.half_origin_hidden_size_per_attention_head, :]
|
||||
qkv_weight_second_half = data.reshape(
|
||||
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
|
||||
)[:, :, self.half_origin_hidden_size_per_attention_head:, :]
|
||||
|
||||
qkv_weight_first_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_first_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_second_half_padded = torch.nn.functional.pad(
|
||||
qkv_weight_second_half,
|
||||
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
|
||||
qkv_weight_padded = torch.cat(
|
||||
[qkv_weight_first_half_padded, qkv_weight_second_half_padded],
|
||||
dim=2)
|
||||
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
|
||||
|
||||
if is_enable_nz(qkv_weight_final.dtype):
|
||||
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
|
||||
qkv_weight_final)
|
||||
qkv_weight_final_copy = torch_npu.npu_format_cast(
|
||||
qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND)
|
||||
return qkv_weight_final_copy
|
||||
|
||||
return qkv_weight_final
|
||||
|
||||
def pad_proj_weight(self, data):
|
||||
out_weight = torch.nn.functional.pad(
|
||||
data.reshape(self.hidden_size, -1,
|
||||
self.half_origin_hidden_size_per_attention_head),
|
||||
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
|
||||
self.hidden_size, -1)
|
||||
|
||||
if is_enable_nz(out_weight.dtype):
|
||||
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
|
||||
out_weight_copy = torch_npu.npu_format_cast(
|
||||
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
|
||||
return out_weight_copy
|
||||
|
||||
return out_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if ("attn.proj.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_proj_weight(param.data)
|
||||
if ("attn.qkv.weight" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_weight(param.data)
|
||||
if ("attn.qkv.bias" in name) and self.enable_pad:
|
||||
param.data = self.pad_qkv_bias(param.data)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
||||
# compute cu_seqlens and avoid cumsum to fit operator unpadFA
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
grid_thw[:,
|
||||
0]).cpu().to(torch.int32)
|
||||
|
||||
# patchify
|
||||
x = x.to(device=self.device, dtype=self.dtype)
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
cos, sin = self.cal_cos_sin(rotary_pos_emb)
|
||||
|
||||
x = x.unsqueeze(1)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, cos=cos, sin=sin)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
return x
|
||||
|
||||
|
||||
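# For intuition (hypothetical grid): the forward above hands the unpad FA
# operator per-frame lengths instead of cumulative offsets, e.g.
#
#   >>> import torch
#   >>> grid_thw = torch.tensor([[2, 4, 6], [1, 8, 8]])
#   >>> torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist()
#   [24, 24, 64]
#
# which is why the usual cumsum step is skipped here.
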
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
|
||||
info=Qwen2VLProcessingInfo,
|
||||
dummy_inputs=Qwen2VLDummyInputsBuilder)
|
||||
class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.visual = AscendQwen2VisionTransformer(
|
||||
self.config.vision_config,
|
||||
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
|
||||
quant_config=vllm_config.quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
676
vllm_npu/models/qwen3_next.py
Normal file
@@ -0,0 +1,676 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
"""Inference-only Qwen3Next model."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from vllm import envs
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
||||
VllmConfig, get_current_vllm_config)
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops import RMSNormGated
|
||||
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
|
||||
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
|
||||
fused_recurrent_gated_delta_rule
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.layernorm import \
|
||||
GemmaRMSNorm as Qwen3NextRMSNorm
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import \
|
||||
mamba_v2_sharded_weight_loader
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
from vllm.model_executor.models.qwen3_next import ( # isort: skip
|
||||
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
|
||||
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
|
||||
fused_gdn_gating)
|
||||
|
||||
|
||||
class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "linear_attention"
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
|
||||
return GDNAttentionBackend
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
|
||||
self.model_config.dtype, self.cache_config.mamba_cache_dtype)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.gated_delta_net_state_shape(
|
||||
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
|
||||
self.head_v_dim, self.conv_kernel_size, self.num_spec)
|
||||
|
||||
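    # Note on the cache layout consumed by _forward below (a reading of the
    # code, not an authoritative spec): self.kv_cache[virtual_engine] holds two
    # tensors, a causal-conv state used with its last two dims transposed
    # (conv_state = self_kv_cache[0].transpose(-1, -2)) and the recurrent SSM
    # state (ssm_state = self_kv_cache[1]); their shapes and dtypes are the
    # ones reported by get_state_shape() / get_state_dtype() above.
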
def __init__(
|
||||
self,
|
||||
config: Qwen3NextConfig,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
speculative_config: Optional[SpeculativeConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_v_heads = config.linear_num_value_heads
|
||||
self.num_k_heads = config.linear_num_key_heads
|
||||
self.head_k_dim = config.linear_key_head_dim
|
||||
self.head_v_dim = config.linear_value_head_dim
|
||||
self.key_dim = self.head_k_dim * self.num_k_heads
|
||||
self.value_dim = self.head_v_dim * self.num_v_heads
|
||||
|
||||
self.conv_kernel_size = config.linear_conv_kernel_dim
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.activation = config.hidden_act
|
||||
self.act = ACT2FN[config.hidden_act]
|
||||
self.layer_norm_epsilon = config.rms_norm_eps
|
||||
self.prefix = prefix
|
||||
|
||||
self.config = config
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
self.speculative_config = speculative_config
|
||||
self.num_spec = (self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0)
|
||||
|
||||
# QKV
|
||||
self.conv_dim = self.key_dim * 2 + self.value_dim
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=self.conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
# projection of the input hidden states
|
||||
self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
|
||||
self.projection_size_ba = self.num_v_heads * 2
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=self.hidden_size,
|
||||
output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
query_key_settings = (self.key_dim, 0, False)
|
||||
value_settings = (self.value_dim, 0, False)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight, {
|
||||
"weight_loader":
|
||||
mamba_v2_sharded_weight_loader([
|
||||
query_key_settings,
|
||||
query_key_settings,
|
||||
value_settings,
|
||||
], self.tp_size, self.tp_rank)
|
||||
})
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
|
||||
# time step projection (discretization)
|
||||
# instantiate once and copy inv_dt in init_weights of PretrainedModel
|
||||
self.dt_bias = nn.Parameter(
|
||||
torch.ones(self.num_v_heads // self.tp_size), )
|
||||
self.A_log = nn.Parameter(
|
||||
torch.empty(
|
||||
divide(self.num_v_heads, self.tp_size),
|
||||
dtype=torch.float32,
|
||||
))
|
||||
|
||||
set_weight_attrs(self.A_log,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
set_weight_attrs(self.dt_bias,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
|
||||
self.norm = RMSNormGated(
|
||||
self.head_v_dim,
|
||||
eps=self.layer_norm_epsilon,
|
||||
norm_before_gate=True,
|
||||
device="npu",
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(self.value_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj")
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
return
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
has_initial_state = attn_metadata.has_initial_state
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
spec_sequence_masks = attn_metadata.spec_sequence_masks
|
||||
spec_token_masks = attn_metadata.spec_token_masks
|
||||
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
|
||||
num_actual_tokens = (attn_metadata.num_prefill_tokens +
|
||||
attn_metadata.num_decode_tokens +
|
||||
attn_metadata.num_spec_decode_tokens)
|
||||
num_accepted_tokens = attn_metadata.num_accepted_tokens
|
||||
|
||||
# 1. Set up dimensions for reshapes later
|
||||
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
|
||||
if spec_token_masks is not None:
|
||||
spec_token_masks = spec_token_masks[:num_actual_tokens]
|
||||
projected_states_qkvz, projected_states_ba = torch.split(
|
||||
projected_states,
|
||||
[
|
||||
self.projection_size_qkvz // self.tp_size,
|
||||
self.projection_size_ba // self.tp_size
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query, key, value, z, b, a = self.fix_query_key_value_ordering(
|
||||
projected_states_qkvz, projected_states_ba)
|
||||
query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
|
||||
(query, key, value))
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if (attn_metadata.num_prefills == 0
|
||||
and attn_metadata.num_decodes == 0):
|
||||
mixed_qkv_spec = mixed_qkv
|
||||
mixed_qkv_non_spec = None
|
||||
else:
|
||||
mixed_qkv_spec = mixed_qkv[spec_token_masks]
|
||||
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
|
||||
else:
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
|
||||
# 2.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "mamba_cache_params.state_indices_tensor"
|
||||
mixed_qkv_non_spec = causal_conv1d_fn(
|
||||
mixed_qkv_non_spec.transpose(0, 1),
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_state,
|
||||
cache_indices=non_spec_state_indices_tensor,
|
||||
query_start_loc=non_spec_query_start_loc,
|
||||
).transpose(0, 1)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
mixed_qkv_non_spec = causal_conv1d_update(
|
||||
mixed_qkv_non_spec,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=non_spec_state_indices_tensor[:attn_metadata
|
||||
.num_decodes],
|
||||
# validate_data=True,
|
||||
)
|
||||
else:
|
||||
mixed_qkv_non_spec = None
|
||||
|
||||
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_spec)
|
||||
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
|
||||
mixed_qkv_non_spec)
|
||||
|
||||
beta = b.sigmoid()
|
||||
g = fused_gdn_gating(self.A_log, a, self.dt_bias)
|
||||
g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta))
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if (attn_metadata.num_prefills == 0
|
||||
and attn_metadata.num_decodes == 0):
|
||||
g_spec = g
|
||||
beta_spec = beta
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
else:
|
||||
g_spec = g[:, spec_token_masks]
|
||||
beta_spec = beta[:, spec_token_masks]
|
||||
g_non_spec = g[:, ~spec_token_masks]
|
||||
beta_non_spec = beta[:, ~spec_token_masks]
|
||||
else:
|
||||
g_spec = None
|
||||
beta_spec = None
|
||||
g_non_spec = g
|
||||
beta_non_spec = beta
|
||||
|
||||
# 3. Recurrent attention
|
||||
# 3.1: process the multi-query part
|
||||
if spec_sequence_masks is not None:
|
||||
core_attn_out_spec, last_recurrent_state = (
|
||||
fused_recurrent_gated_delta_rule(
|
||||
q=query_spec,
|
||||
k=key_spec,
|
||||
v=value_spec,
|
||||
g=g_spec,
|
||||
beta=beta_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[:attn_metadata.num_spec_decodes + 1],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
))
|
||||
else:
|
||||
core_attn_out_spec, last_recurrent_state = None, None
|
||||
|
||||
# 3.2: process the remaining part
|
||||
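# Prefill requests are processed sequence by sequence with
# chunk_gated_delta_rule, and the per-sequence outputs and final recurrent
# states are gathered back afterwards; decode-only requests use the fused
# single-step recurrent kernel instead.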
if attn_metadata.num_prefills > 0:
|
||||
initial_state = ssm_state[
|
||||
non_spec_state_indices_tensor].contiguous()
|
||||
initial_state[~has_initial_state, ...] = 0
|
||||
|
||||
batch_size = initial_state.shape[0]
|
||||
core_attn_out = []
|
||||
last_recurrent_state = []
|
||||
|
||||
for b_idx in range(batch_size):
|
||||
start, end = non_spec_query_start_loc[
|
||||
b_idx], non_spec_query_start_loc[b_idx + 1]
|
||||
cur_q = query_non_spec[:, start:end, ...]
|
||||
cur_k = key_non_spec[:, start:end, ...]
|
||||
cur_v = value_non_spec[:, start:end, ...]
|
||||
cur_g = g_non_spec[:, start:end, ...]
|
||||
cur_b = beta_non_spec[:, start:end, ...]
|
||||
cur_state = initial_state[b_idx].unsqueeze(0)
|
||||
|
||||
(
|
||||
cur_core_attn_out_non_spec,
|
||||
cur_last_recurrent_state,
|
||||
) = chunk_gated_delta_rule(
|
||||
query=cur_q,
|
||||
key=cur_k,
|
||||
value=cur_v,
|
||||
g=cur_g,
|
||||
beta=cur_b,
|
||||
initial_state=cur_state,
|
||||
output_final_state=True,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
core_attn_out.append(cur_core_attn_out_non_spec)
|
||||
last_recurrent_state.append(cur_last_recurrent_state)
|
||||
|
||||
tar_dtype = core_attn_out[0].dtype
|
||||
tar_device = core_attn_out[0].device
|
||||
tar_shape = list(core_attn_out[0].shape)
|
||||
tar_shape[1] = non_spec_query_start_loc[-1]
|
||||
core_attn_out_non_spec = torch.empty(tar_shape,
|
||||
dtype=tar_dtype,
|
||||
device=tar_device)
|
||||
for b_idx in range(batch_size):
|
||||
cur_core_attn_out = core_attn_out[b_idx]
|
||||
start, end = non_spec_query_start_loc[
|
||||
b_idx], non_spec_query_start_loc[b_idx + 1]
|
||||
core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
|
||||
last_recurrent_state = torch.cat(last_recurrent_state, dim=0)
|
||||
|
||||
# Write the final recurrent state back into the SSM state cache
|
||||
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
|
||||
ssm_state.dtype)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
core_attn_out_non_spec, last_recurrent_state = (
|
||||
fused_recurrent_gated_delta_rule(
|
||||
q=query_non_spec,
|
||||
k=key_non_spec,
|
||||
v=value_non_spec,
|
||||
g=g_non_spec,
|
||||
beta=beta_non_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc[:attn_metadata.num_decodes + 1],
|
||||
ssm_state_indices=non_spec_state_indices_tensor,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
))
|
||||
else:
|
||||
core_attn_out_non_spec, last_recurrent_state = None, None
|
||||
|
||||
# Merge core attention output
|
||||
if (spec_sequence_masks is not None
|
||||
and core_attn_out_non_spec is not None):
|
||||
core_attn_out = torch.empty(
|
||||
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
|
||||
dtype=core_attn_out_non_spec.dtype,
|
||||
device=core_attn_out_non_spec.device,
|
||||
)
|
||||
core_attn_out[:, spec_token_masks] = core_attn_out_spec
|
||||
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
|
||||
elif spec_sequence_masks is not None:
|
||||
core_attn_out = core_attn_out_spec
|
||||
else:
|
||||
core_attn_out = core_attn_out_non_spec
|
||||
|
||||
z_shape_og = z.shape
|
||||
# reshape input data into 2D tensor
|
||||
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
core_attn_out = self.norm(core_attn_out, z)
|
||||
core_attn_out = core_attn_out.reshape(z_shape_og)
|
||||
core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)')
|
||||
|
||||
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
|
||||
|
||||
|
||||
class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
layer_type: str,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
|
||||
self.layer_type = layer_type
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = CustomQwen3NextGatedDeltaNet(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
speculative_config=speculative_config,
|
||||
prefix=f'{prefix}.linear_attn')
|
||||
elif self.layer_type == "full_attention":
|
||||
self.self_attn = Qwen3NextAttention(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f'{prefix}.self_attn',
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid layer_type {self.layer_type}")
|
||||
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
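# Example (illustrative): with decoder_sparse_step=1 every layer outside
# mlp_only_layers gets a sparse MoE block; with decoder_sparse_step=2 only
# every second layer does, and the rest fall back to the dense Qwen3NextMLP.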
if (self.layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3NextMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen3NextRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.layer_scale = getattr(config, "layer_scale", False)
|
||||
if self.layer_scale:
|
||||
self.attn_layer_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
), )
|
||||
self.ffn_layer_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
1,
|
||||
1,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
), )
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3NextModel(Qwen3NextModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config: Qwen3NextConfig = vllm_config.model_config.hf_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
lora_config = vllm_config.lora_config
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
self.config = config
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
def get_layer(prefix: str):
|
||||
return CustomQwen3NextDecoderLayer(
|
||||
vllm_config,
|
||||
layer_type=config.layer_types[extract_layer_index(prefix)],
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
self.norm = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
("in_proj", "in_proj_qkvz", 0),
|
||||
("in_proj", "in_proj_ba", 1),
|
||||
]
|
||||
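# Example (illustrative): a checkpoint key such as
# "model.layers.0.self_attn.q_proj.weight" is renamed to
# "model.layers.0.self_attn.qkv_proj.weight" and loaded into the fused
# parameter with shard_id "q"; the gate/up and in_proj pairs are fused the
# same way via their 0/1 shard ids.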
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if name.startswith("mtp."):
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if "mlp.experts" in name:
|
||||
continue
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# name = apply_attn_prefix(name, params_dict)
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Qwen3Next currently does not support prefix caching"
|
||||
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model = CustomQwen3NextModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights = []
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer)
|
||||
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
|
||||
example_layer = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError("No Qwen3Next layer found in the model.layers.")
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
self.num_logical_experts = example_layer.n_logical_experts
|
||||
self.num_physical_experts = example_layer.n_physical_experts
|
||||
self.num_local_physical_experts = example_layer.n_local_physical_experts
|
||||
self.num_routed_experts = example_layer.n_routed_experts
|
||||
self.num_redundant_experts = example_layer.n_redundant_experts
|
||||
0
vllm_npu/multistream/__init__.py
Normal file
29
vllm_npu/multistream/base.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MSEventKey(Enum):
|
||||
ATTN_COM_FINISH = 0
|
||||
ATTN_AR_FINISH = 1
|
||||
FFN_COM_FINISH = 2
|
||||
FFN_AR_FINISH = 3
|
||||
# events for MOE dispatch and combine
|
||||
MOE_BEFORE_COMM = 4
|
||||
MOE_AFTER_COMM = 5
|
||||
# events for shared expert
|
||||
MOE_SE_COMM_FINISH = 6
|
||||
MOE_SE_COMP_FINISH = 7
|
||||
MOE_GATE_FINISH = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class MSAttentionMetadataSplitConfig:
|
||||
"""
|
||||
Micro-batch split configuration for splitting attention metadata.
|
||||
"""
|
||||
# number of micro batches
|
||||
num_micro_batches: int = 2
|
||||
# split micro batches only when total tokens >= min_total_tokens_to_split
|
||||
min_total_tokens_to_split: int = 256
|
||||
# split micro batches only when prefill tokens >= min_prefill_tokens_to_split
|
||||
min_prefill_tokens_to_split: int = 64
|
||||
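# A minimal sketch (helper name is hypothetical) of how these thresholds are
# typically consulted before splitting a batch into micro batches:
#
#     def should_split(total_tokens, prefill_tokens,
#                      cfg: MSAttentionMetadataSplitConfig) -> bool:
#         return (total_tokens >= cfg.min_total_tokens_to_split
#                 and prefill_tokens >= cfg.min_prefill_tokens_to_split)
#
# e.g. 300 total / 80 prefill tokens would split, 200 total tokens would not.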
67
vllm_npu/multistream/context.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
_ms_comm_context: Any = None
|
||||
_cur_micro_batch_num: int = -1
|
||||
_ms_layer_index_context: int = -1
|
||||
_ms_metadata_context: Any = None
|
||||
_ms_attn_metadata_context: Any = None
|
||||
|
||||
|
||||
def set_multistream_layer_context(start_layer: int, ms_metadata: Any,
|
||||
attn_metadata: Any):
|
||||
"""
|
||||
set multistream layer context before transformer layers
|
||||
"""
|
||||
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
_ms_layer_index_context = start_layer
|
||||
_ms_metadata_context = ms_metadata
|
||||
_ms_attn_metadata_context = attn_metadata
|
||||
|
||||
|
||||
def reset_multistream_layer_context():
|
||||
"""
|
||||
reset multistream layer context
|
||||
"""
|
||||
global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
_ms_layer_index_context = -1
|
||||
_ms_metadata_context = None
|
||||
_ms_attn_metadata_context = None
|
||||
|
||||
|
||||
def get_multistream_layer_context():
|
||||
"""
|
||||
get multistream layer context
|
||||
"""
|
||||
return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context
|
||||
|
||||
|
||||
def advance_step_multistream_layer_context():
|
||||
"""
|
||||
advance multistream layer index context
|
||||
"""
|
||||
global _ms_layer_index_context
|
||||
_ms_layer_index_context += 1
|
||||
|
||||
|
||||
def get_multistream_comm_context() -> Any:
|
||||
"""Get the current comm forward context."""
|
||||
return _ms_comm_context
|
||||
|
||||
|
||||
def get_multistream_microbatch_context() -> int:
|
||||
return _cur_micro_batch_num
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_multistream_context(context: Any, micro_batch_num: int):
|
||||
"""A context manager that stores the current comm forward context,
|
||||
can be attention metadata, etc."""
|
||||
global _ms_comm_context, _cur_micro_batch_num
|
||||
_ms_comm_context = context
|
||||
_cur_micro_batch_num = micro_batch_num
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_ms_comm_context = None
|
||||
_cur_micro_batch_num = -1
|
||||
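# Minimal usage sketch (the metadata object is a placeholder):
#
#     with set_multistream_context(attn_metadata_micro_batch, micro_batch_num=0):
#         ctx = get_multistream_comm_context()          # -> the stored metadata
#         idx = get_multistream_microbatch_context()    # -> 0
#     # on exit both globals are reset to None / -1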
22
vllm_npu/multistream/decorator.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from .context import (get_multistream_layer_context,
|
||||
get_multistream_microbatch_context)
|
||||
|
||||
|
||||
# vLLM v1 uses get_forward_context to fetch the attn_metadata; this decorator
# swaps in the attention metadata of the current micro batch.
|
||||
def set_multistream_support():
|
||||
|
||||
def decorator(func):
|
||||
|
||||
def wrapper():
|
||||
context = func()
|
||||
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
|
||||
)
|
||||
micro_batch_num = get_multistream_microbatch_context()
|
||||
if layer_index != -1 and micro_batch_num != -1:
|
||||
context.attn_metadata = attn_metadata[micro_batch_num]
|
||||
return context
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
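# Hedged usage sketch: wrapping vLLM's get_forward_context so that the
# returned context carries the attention metadata of the current micro batch
# whenever a multistream layer context is active:
#
#     from vllm.forward_context import get_forward_context
#     patched = set_multistream_support()(get_forward_context)
#     forward_ctx = patched()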
61
vllm_npu/multistream/layers.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from .base import MSEventKey
|
||||
from .context import (get_multistream_layer_context,
|
||||
reset_multistream_layer_context,
|
||||
set_multistream_layer_context)
|
||||
from .metadata import MultiStreamMetadata
|
||||
|
||||
|
||||
class MultiStreamPreTransformerLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||
super().__init__()
|
||||
self.multistream_metadata = multistream_metadata
|
||||
|
||||
def forward(
|
||||
self,
|
||||
intput_tensors: List[torch.Tensor],
|
||||
):
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if self.multistream_metadata is None or attn_metadata is None:
|
||||
set_multistream_layer_context(-1, None, None)
|
||||
return attn_metadata, intput_tensors
|
||||
# TODO add attn_metadata management
|
||||
do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch(
|
||||
attn_metadata, intput_tensors)
|
||||
if do_ms:
|
||||
set_multistream_layer_context(
|
||||
self.multistream_metadata.start_layer,
|
||||
self.multistream_metadata, attn_metadata)
|
||||
else:
|
||||
set_multistream_layer_context(-1, None, None)
|
||||
return attn_metadata, intput_tensors
|
||||
|
||||
|
||||
class MultiStreamPostTransformerLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, multistream_metadata: MultiStreamMetadata):
|
||||
super().__init__()
|
||||
self.multistream_metadata = multistream_metadata
|
||||
|
||||
def forward(self,
|
||||
input_tensors: Union[List[Tuple[torch.Tensor]],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]]],
|
||||
wait_layer_index: Optional[int] = None):
|
||||
if self.multistream_metadata is None or self.multistream_metadata.ms_config is None:
|
||||
return input_tensors
|
||||
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context(
|
||||
)
|
||||
if layer_index >= 0:
|
||||
true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index
|
||||
self.multistream_metadata.try_wait_event(
|
||||
true_wait_layer,
|
||||
self.multistream_metadata.ms_config.num_micro_batches - 1,
|
||||
MSEventKey.FFN_AR_FINISH)
|
||||
reset_multistream_layer_context()
|
||||
return self.multistream_metadata.merge_micro_batches(input_tensors)
|
||||
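# Hedged sketch of how these layers are meant to bracket the decoder stack
# (ms_metadata construction is shown in metadata.py below):
#
#     pre_layer = MultiStreamPreTransformerLayer(ms_metadata)
#     post_layer = MultiStreamPostTransformerLayer(ms_metadata)
#     attn_metadata, inputs = pre_layer([hidden_states, residual])
#     ...  # run the decoder layers, interleaving the two micro batches
#     hidden_states, residual = post_layer([hidden_states, residual])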
182
vllm_npu/multistream/metadata.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_npu.attention.mla_v1 import AscendMLAMetadata
|
||||
|
||||
from .base import MSAttentionMetadataSplitConfig, MSEventKey
|
||||
|
||||
|
||||
def split_micro_batches_tensors(input_tensors,
|
||||
split_index: int,
|
||||
keys: Optional[List[str]] = None):
|
||||
if isinstance(input_tensors, list):
|
||||
micro_batches = []
|
||||
for tensor in input_tensors:
|
||||
if tensor is None:
|
||||
micro_batches.append([None, None])
|
||||
else:
|
||||
micro_batches.append(
|
||||
[tensor[:split_index], tensor[split_index:]])
|
||||
return micro_batches
|
||||
elif isinstance(input_tensors, torch.Tensor):
|
||||
return [input_tensors[:split_index], input_tensors[split_index:]]
|
||||
elif input_tensors is None:
|
||||
return [None, None]
|
||||
elif isinstance(input_tensors, Dict):
|
||||
assert keys is not None
|
||||
micro_batches_pre = {}
|
||||
for key in keys:
|
||||
micro_batches_pre[key] = input_tensors[key][:split_index]
|
||||
micro_batches_post = {}
|
||||
for key in keys:
|
||||
micro_batches_post[key] = input_tensors[key][split_index:]
|
||||
return [micro_batches_pre, micro_batches_post]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
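# Example (illustrative): for input_tensors = torch.arange(6) and
# split_index = 4 this returns [tensor([0, 1, 2, 3]), tensor([4, 5])];
# lists are split element-wise and dicts key-wise in the same way.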
|
||||
|
||||
@dataclass
|
||||
class MultiStreamStepMetadata:
|
||||
comm_stream: torch.npu.Stream = None
|
||||
before_comm_event: torch.npu.Event = None
|
||||
after_comm_event: torch.npu.Event = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStreamConfig:
|
||||
"""Controls the behavior of multi-stream models."""
|
||||
min_total_tokens_to_split: int = 256
|
||||
min_prefill_tokens_to_split: int = 64
|
||||
num_micro_batches: int = 2
|
||||
imbalance_ratio: float = 0.1
|
||||
|
||||
|
||||
class MultiStreamMetadata:
|
||||
# direct stream
|
||||
calculate_stream = None
|
||||
# delay stream
|
||||
communicate_stream = None
|
||||
# events
|
||||
ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {}
|
||||
# multi-stream-flag
|
||||
enable_multi_stream: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculate_stream: torch.npu.Stream,
|
||||
communicate_stream: torch.npu.Stream,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
event_keys: List[MSEventKey],
|
||||
multistream_config: Optional[MultiStreamConfig],
|
||||
causal_lm: bool = True,
|
||||
):
|
||||
self.calculate_stream = calculate_stream
|
||||
self.communicate_stream = communicate_stream
|
||||
self.start_layer = start_layer
|
||||
self.end_layer = end_layer
|
||||
self.ms_config = multistream_config
|
||||
self.causal_lm = causal_lm
|
||||
self._build_events(event_keys)
|
||||
self._build_ms_split_config()
|
||||
|
||||
def _build_events(self, event_keys):
|
||||
if self.ms_config is not None:
|
||||
for i in range(self.start_layer - 1, self.end_layer):
|
||||
self.ms_events[i] = {}
|
||||
for j in range(self.ms_config.num_micro_batches):
|
||||
self.ms_events[i][j] = {}
|
||||
for key in event_keys:
|
||||
self.ms_events[i][j][key] = torch.npu.Event()
|
||||
|
||||
def _build_ms_split_config(self):
|
||||
if self.ms_config is not None:
|
||||
self.ms_split_config = MSAttentionMetadataSplitConfig(
|
||||
num_micro_batches=self.ms_config.num_micro_batches,
|
||||
min_total_tokens_to_split=self.ms_config.min_total_tokens_to_split,
min_prefill_tokens_to_split=self.ms_config.min_prefill_tokens_to_split,
|
||||
)
|
||||
|
||||
def try_wait_event(self, layer_index: int, micro_batch_index: int,
|
||||
event_key: MSEventKey):
|
||||
self.ms_events[layer_index][micro_batch_index][event_key].wait()
|
||||
|
||||
def try_record_event(self, layer_index: int, micro_batch_index: int,
|
||||
event_key: MSEventKey):
|
||||
self.ms_events[layer_index][micro_batch_index][event_key].record()
|
||||
|
||||
def split_micro_batch(
|
||||
self,
|
||||
attn_metadata: "AscendMLAMetadata",
|
||||
intput_tensors: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
intermediate_tensors_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
|
||||
List[torch.Tensor], List[List[torch.Tensor]]], Union[
|
||||
IntermediateTensors, List[IntermediateTensors]]]:
|
||||
attn_metadata_list = attn_metadata.split_metadata_for_multistream(
|
||||
self.ms_split_config)
|
||||
if len(attn_metadata_list) == 1:
|
||||
return False, attn_metadata_list[
|
||||
0], intput_tensors, intermediate_tensors
|
||||
split_index = attn_metadata_list[0].slot_mapping.shape[0]
|
||||
input_tensors = split_micro_batches_tensors(intput_tensors,
|
||||
split_index)
|
||||
if intermediate_tensors is not None:
|
||||
inter_tensors_list = split_micro_batches_tensors(
|
||||
intermediate_tensors.tensors, split_index,
|
||||
intermediate_tensors_keys)
|
||||
intermediate_tensors = [
|
||||
IntermediateTensors(inter_tensors)
|
||||
for inter_tensors in inter_tensors_list
|
||||
]
|
||||
return True, attn_metadata_list, input_tensors, intermediate_tensors
|
||||
|
||||
def merge_micro_batches(
|
||||
self, input_tensors: Union[List[torch.Tensor],
|
||||
List[List[torch.Tensor]]]
|
||||
) -> List[torch.Tensor]:
|
||||
if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
|
||||
return input_tensors
|
||||
batch: List[Optional[torch.Tensor]] = []
|
||||
for tensors in input_tensors:
|
||||
if tensors is None or tensors[0] is None:
|
||||
batch.append(None)
|
||||
else:
|
||||
batch.append(torch.cat(tensors, dim=0))
|
||||
return batch
|
||||
|
||||
|
||||
def make_multistream_metadata_ds(
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
causal_lm: bool = True,
|
||||
multistream_config: Optional[MultiStreamConfig] = None,
|
||||
):
|
||||
if multistream_config is None:
|
||||
return None
|
||||
event_keylist = [
|
||||
MSEventKey.ATTN_COM_FINISH,
|
||||
MSEventKey.ATTN_AR_FINISH,
|
||||
MSEventKey.FFN_COM_FINISH,
|
||||
MSEventKey.FFN_AR_FINISH,
|
||||
MSEventKey.MOE_BEFORE_COMM,
|
||||
MSEventKey.MOE_AFTER_COMM,
|
||||
MSEventKey.MOE_SE_COMM_FINISH,
|
||||
MSEventKey.MOE_SE_COMP_FINISH,
|
||||
MSEventKey.MOE_GATE_FINISH,
|
||||
]
|
||||
return MultiStreamMetadata(
|
||||
calculate_stream=torch.npu.current_stream(),
|
||||
communicate_stream=torch.npu.Stream(),
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
multistream_config=multistream_config,
|
||||
event_keys=event_keylist,
|
||||
causal_lm=causal_lm,
|
||||
)
|
||||
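# Hedged usage sketch for the factory above (layer indices are illustrative):
#
#     ms_config = MultiStreamConfig(num_micro_batches=2)
#     ms_metadata = make_multistream_metadata_ds(
#         start_layer=3,
#         end_layer=61,
#         causal_lm=True,
#         multistream_config=ms_config,
#     )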
247
vllm_npu/multistream/ms_split.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm_npu.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
from .base import MSAttentionMetadataSplitConfig
|
||||
|
||||
|
||||
def compute_split_seq_index(
|
||||
query_lens: Optional[list[int]],
|
||||
attn_state: AscendAttentionState,
|
||||
num_tokens: int,
|
||||
imbalance_ratio: float = 0.1,
|
||||
) -> list[int]:
|
||||
if attn_state != AscendAttentionState.DecodeOnly:
|
||||
assert query_lens is not None
|
||||
total_tokens = sum(query_lens)
|
||||
# tokens: size of the first split; split_index: first sequence index of the second split
|
||||
tokens, split_index = 0, 0
|
||||
for value in query_lens:
|
||||
tokens += value
|
||||
split_index += 1
|
||||
if tokens >= total_tokens // 2:
|
||||
# check the current split index
|
||||
if abs(tokens -
|
||||
total_tokens // 2) < total_tokens * imbalance_ratio:
|
||||
return [tokens, split_index]
|
||||
# check the previous split index
|
||||
elif abs(tokens - total_tokens // 2 -
|
||||
value) < total_tokens * imbalance_ratio:
|
||||
return [tokens - value, split_index - 1]
|
||||
# fail to split if it is imbalanced
|
||||
# TODO: split tokens in seq
|
||||
else:
|
||||
return [0, 0]
|
||||
else:
|
||||
tokens = num_tokens // 2
|
||||
return [tokens, tokens]
|
||||
return [0, 0]
|
||||
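# Worked example (illustrative): query_lens = [40, 90, 70, 60], so the total
# is 260 and the target is 130; the running sum reaches exactly 130 after two
# sequences, which is within the default 10% imbalance ratio, so [130, 2] is
# returned and micro batch 0 gets the first two sequences.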
|
||||
|
||||
def split_attn_tensor_type(
|
||||
input_tensor: torch.Tensor,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [input_tensor[:index], input_tensor[index:]]
|
||||
|
||||
|
||||
def split_attn_int_type(
|
||||
var: int,
|
||||
index: int,
|
||||
) -> List[torch.Tensor]:
|
||||
return [min(var, index), max(var - index, 0)]
|
||||
|
||||
|
||||
def model_input_split_v1_mla_attn(
|
||||
attn_metadata,
|
||||
_metadata_cls,
|
||||
ms_split_config: MSAttentionMetadataSplitConfig,
|
||||
) -> List[Any]:
|
||||
assert 0 < ms_split_config.num_micro_batches < 3
|
||||
if attn_metadata is None:
|
||||
return [attn_metadata]
|
||||
[token_index,
|
||||
seq_index] = compute_split_seq_index(attn_metadata.query_lens,
|
||||
attn_metadata.attn_state,
|
||||
attn_metadata.num_decode_tokens)
|
||||
if token_index == 0 or seq_index == 0 or seq_index == len(
|
||||
attn_metadata.query_lens):
|
||||
return [attn_metadata]
|
||||
|
||||
query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
|
||||
dtype=int)
|
||||
np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
|
||||
if attn_metadata.num_prefills > 0:
|
||||
prefill_query_start_loc = np.zeros(
|
||||
shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
|
||||
np.cumsum(attn_metadata.prefill.query_lens,
|
||||
out=prefill_query_start_loc[1:])
|
||||
|
||||
# split attn metadata
|
||||
[slot_mapping_pre,
|
||||
slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
|
||||
token_index)
|
||||
[num_decodes_pre,
|
||||
num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
|
||||
seq_index)
|
||||
[num_decode_tokens_pre, num_decode_tokens_post
|
||||
] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
|
||||
[num_prefills_pre, num_prefills_post
|
||||
] = split_attn_int_type(attn_metadata.num_prefills,
|
||||
max(0, seq_index - attn_metadata.num_decodes))
|
||||
seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
|
||||
[seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
|
||||
query_start_loc_pre = query_start_loc_post = None
|
||||
if attn_metadata.query_start_loc is not None:
|
||||
query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
|
||||
query_start_loc_post = deepcopy(
|
||||
attn_metadata.query_start_loc[seq_index:]
|
||||
) - attn_metadata.query_start_loc[seq_index]
|
||||
[block_table_pre,
|
||||
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
|
||||
seq_index)
|
||||
assert attn_metadata.attn_mask is not None
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
# the attn_mla kernel in torch_npu only accepts a 128*128 attn mask
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = attn_metadata.attn_state
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
# attn_mask should be None in the decode-only state
|
||||
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
# chunked prefill
|
||||
if num_prefills_pre > 0:
|
||||
attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
|
||||
seq_lens_pre)].contiguous()
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
else:
|
||||
attn_state_pre = AscendAttentionState.DecodeOnly
|
||||
attn_mask_pre = None
|
||||
attn_state_post = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask_post = attn_metadata.attn_mask[
|
||||
token_index:, :max(seq_lens_post)].contiguous()
|
||||
from vllm_npu.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||
AscendMLAPrefillMetadata)
|
||||
if num_prefills_pre > 0:
|
||||
# split metadata.prefill
|
||||
[input_positions_pre, input_positions_post] = split_attn_tensor_type(
|
||||
attn_metadata.prefill.input_positions,
|
||||
token_index - attn_metadata.num_decode_tokens)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.block_table,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
[prefill_query_lens_pre, prefill_query_lens_post
|
||||
] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
|
||||
seq_index - attn_metadata.num_decodes)
|
||||
prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[
    :seq_index + 1 - attn_metadata.num_decodes]
|
||||
prefill_query_start_loc_post = deepcopy(
|
||||
attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes:]
|
||||
) - attn_metadata.prefill.query_start_loc[seq_index -
|
||||
attn_metadata.num_decodes]
|
||||
context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
|
||||
context_len_post = seq_lens_post
|
||||
prefill_max_query_len_pre = max(prefill_query_lens_pre)
|
||||
prefill_max_query_len_post = max(prefill_query_lens_post)
|
||||
prefill_pre = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_pre,
|
||||
query_lens=prefill_query_lens_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=prefill_query_start_loc_pre,
|
||||
input_positions=input_positions_pre,
|
||||
context_lens=context_len_pre,
|
||||
block_table=block_tables_pre,
|
||||
max_query_len=prefill_max_query_len_pre,
|
||||
max_seq_lens=context_len_pre.max().item(),
|
||||
)
|
||||
prefill_post = AscendMLAPrefillMetadata(
|
||||
attn_mask=attn_mask_post,
|
||||
query_lens=prefill_query_lens_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=prefill_query_start_loc_post,
|
||||
input_positions=input_positions_post,
|
||||
context_lens=context_len_post,
|
||||
block_table=block_tables_post,
|
||||
max_query_len=prefill_max_query_len_post,
|
||||
max_seq_lens=context_len_post.max().item(),
|
||||
)
|
||||
decode_pre = attn_metadata.decode
|
||||
decode_post = None
|
||||
else:
|
||||
# prefill is None, split metadata.decode
|
||||
[input_positions_pre, input_positions_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.input_positions,
|
||||
token_index)
|
||||
[block_tables_pre, block_tables_post
|
||||
] = split_attn_tensor_type(attn_metadata.decode.block_table,
|
||||
seq_index)
|
||||
[decode_seq_lens_pre,
|
||||
decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
|
||||
decode_pre = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_pre,
|
||||
block_table=block_tables_pre,
|
||||
seq_lens=decode_seq_lens_pre,
|
||||
max_seq_lens=max(decode_seq_lens_pre),
|
||||
seq_lens_list=decode_seq_lens_pre.tolist(),
|
||||
)
|
||||
decode_post = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions_post,
|
||||
block_table=block_tables_post,
|
||||
seq_lens=decode_seq_lens_post,
|
||||
max_seq_lens=max(decode_seq_lens_post),
|
||||
seq_lens_list=decode_seq_lens_post.tolist(),
|
||||
)
|
||||
prefill_pre = None
|
||||
prefill_post = attn_metadata.prefill
|
||||
# construct metadata
|
||||
from vllm_npu.attention.mla_v1 import AscendMLAPrefillMetadata
|
||||
attention_metadata_pre = _metadata_cls(
|
||||
num_actual_tokens=token_index,
|
||||
num_input_tokens=token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_pre,
|
||||
seq_lens=seq_lens_pre,
|
||||
query_start_loc=query_start_loc_pre,
|
||||
block_tables=block_table_pre,
|
||||
num_decodes=num_decodes_pre,
|
||||
num_prefills=num_prefills_pre,
|
||||
num_decode_tokens=num_decode_tokens_pre,
|
||||
attn_state=attn_state_pre,
|
||||
attn_mask=attn_mask_pre,
|
||||
prefill=prefill_pre,
|
||||
decode=decode_pre,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
attention_metadata_post = _metadata_cls(
|
||||
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
|
||||
num_input_tokens=attn_metadata.num_input_tokens - token_index,
|
||||
head_dim=attn_metadata.head_dim,
|
||||
slot_mapping=slot_mapping_post,
|
||||
seq_lens=seq_lens_post,
|
||||
query_start_loc=query_start_loc_post,
|
||||
block_tables=block_table_post,
|
||||
num_decodes=num_decodes_post,
|
||||
num_prefills=num_prefills_post,
|
||||
num_decode_tokens=num_decode_tokens_post,
|
||||
attn_mask=attn_mask_post,
|
||||
attn_state=attn_state_post,
|
||||
prefill=prefill_post,
|
||||
decode=decode_post,
|
||||
enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
return [attention_metadata_pre, attention_metadata_post]
|
||||
@@ -1 +1,57 @@
|
||||
"""Ascend NPU custom op registrations."""
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
import vllm_npu.ops.common_fused_moe # noqa
|
||||
import vllm_npu.ops.layernorm # noqa
|
||||
import vllm_npu.ops.register_custom_ops # noqa
|
||||
import vllm_npu.ops.vocab_parallel_embedding # noqa
|
||||
from vllm_npu.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_npu.ops.rotary_embedding import (
|
||||
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
|
||||
|
||||
|
||||
class dummyFusionOp:
|
||||
default = None
|
||||
|
||||
def __init__(self, name=""):
|
||||
self.name = name
|
||||
|
||||
|
||||
def register_dummy_fusion_op() -> None:
|
||||
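# Placeholder registrations: each dummy op only provides the attribute
# name (its .default handle stays None), presumably so code that probes
# torch.ops._C_ascend for these fused kernels does not fail on NPU.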
torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm")
|
||||
torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(
|
||||
name="fused_add_rms_norm")
|
||||
torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(
|
||||
name="static_scaled_fp8_quant")
|
||||
torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(
|
||||
name="dynamic_scaled_fp8_quant")
|
||||
torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
|
||||
name="dynamic_per_token_scaled_fp8_quant")
|
||||
torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
name="rms_norm_static_fp8_quant")
|
||||
torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
|
||||
name="fused_add_rms_norm_static_fp8_quant")
|
||||
torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(
|
||||
name="rms_norm_dynamic_per_token_quant")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding",
|
||||
"AscendDeepseekScalingRotaryEmbedding"
|
||||
]
|
||||
|
||||
@@ -1,17 +1,44 @@
|
||||
"""
|
||||
NPU-optimized activation functions for Ascend.
|
||||
|
||||
Provides ``AscendSiluAndMul`` that uses ``torch_npu.npu_swiglu`` for
|
||||
fused SiLU+Mul on NPU devices.
|
||||
"""
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
|
||||
|
||||
|
||||
class AscendQuickGELU(QuickGELU):
|
||||
|
||||
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
import torch_npu
|
||||
|
||||
out = torch_npu.npu_fast_gelu(x)
|
||||
return out
|
||||
|
||||
|
||||
class AscendSiluAndMul(SiluAndMul):
|
||||
"""SiluAndMul using torch_npu.npu_swiglu on Ascend NPU."""
|
||||
|
||||
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
import torch_npu # noqa: F401
|
||||
return torch_npu.npu_swiglu(x)
|
||||
import torch_npu
|
||||
|
||||
from vllm_npu.utils import is_310p
|
||||
|
||||
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
|
||||
if is_310p():
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(out)
|
||||
return out
|
||||
|
||||
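# Hedged usage sketch for the override above (NPU device assumed; the last
# dimension follows vLLM's [gate, up] convention):
#
#     layer = AscendSiluAndMul()
#     x = torch.randn(8, 2 * 1408)   # (num_tokens, 2 * intermediate_size)
#     y = layer.forward_oot(x)       # -> (8, 1408) via torch_npu.npu_swiglu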
309
vllm_npu/ops/attention.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
|
||||
|
||||
# Reference implementation of vanilla chunked prefill; should be removed once
# the kernel is ready for all the corner cases.
|
||||
def vanilla_chunked_prefill(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor, # (num_tokens, heads, head_size)
|
||||
key_cache: torch.Tensor, # (num_blocks, block_size, kv_heads, head_size)
|
||||
value_cache: torch.
|
||||
Tensor, # (num_blocks, block_size, kv_heads, head_size,)
|
||||
block_tables: torch.Tensor, # (num_seqs, max_num_blocks_per_seq)
|
||||
cu_seqlen_q: torch.Tensor, # (num_seqs + 1,)
|
||||
cu_seqlen_k: torch.Tensor, # (num_seqs + 1,)
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
num_query_heads = query.shape[1]
|
||||
head_dim = value_cache.shape[3]
|
||||
num_kv_heads = value_cache.shape[2]
|
||||
block_size = value_cache.shape[1]
|
||||
num_batch = cu_seqlen_q.shape[0] - 1
|
||||
max_num_blocks_per_seq = block_tables.shape[1]
|
||||
|
||||
key = key_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
|
||||
value = value_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
key = key[:, :max_seqlen_k, :, :]
|
||||
value = value[:, :max_seqlen_k, :, :]
|
||||
|
||||
seqlen_k = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
|
||||
seqlen_q = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
|
||||
seqlen_q = seqlen_q.view(-1, 1)
|
||||
seqlen_k = seqlen_k.view(-1, 1)
|
||||
seqlen_diff = seqlen_k - seqlen_q
|
||||
q_idx_mask = (torch.arange(0, max_seqlen_q,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
k_idx_mask = (torch.arange(0, max_seqlen_k,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
q_mask = q_idx_mask < seqlen_q
|
||||
k_mask = k_idx_mask < seqlen_k
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(torch.ones(max_seqlen_k, max_seqlen_k,
|
||||
device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[num_batch, max_seqlen_q, num_query_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[k_mask] = key[k_mask]
|
||||
pad_v[k_mask] = value[k_mask]
|
||||
|
||||
if num_query_heads > num_kv_heads:
|
||||
pad_k = pad_k.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_k = pad_k.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
pad_v = pad_v.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_v = pad_v.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
# permute to [b, h, n, k]
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_query_heads,
|
||||
head_dim]).to(output.dtype))
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
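# Shape walk-through (illustrative): with num_batch=2, max_seqlen_q=8,
# max_seqlen_k=32, 4 query heads and head_dim=128, Q/K/V are padded to dense
# (2, 4, 8/32, 128) tensors, the padding and causal masks are applied in
# float32, and the (sum(query_lens), 4, 128) result is scattered back into
# `output`.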
|
||||
|
||||
def vanilla_chunked_prefill_mla(
|
||||
output: torch.Tensor, # (num_tokens, num_heads, v_head_dim)
|
||||
query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim)
|
||||
kv_cache: Tuple[
|
||||
torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv)
|
||||
block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq)
|
||||
query_lens: torch.Tensor, # (batch_size)
|
||||
context_lens: torch.Tensor, # (batch_size)
|
||||
kv_b_proj: ColumnParallelLinear, # ()
|
||||
max_query_len: int,
|
||||
max_context_len: int,
|
||||
nope_dim: int,
|
||||
rope_dim: int,
|
||||
v_head_dim: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True) -> None:
|
||||
batch_size = block_tables.size(0)
|
||||
assert len(kv_cache) > 1
|
||||
assert query_lens.size(0) == batch_size
|
||||
num_heads = query.size(1)
|
||||
nope_cache = kv_cache[0]
|
||||
rope_cache = kv_cache[1]
|
||||
block_size = nope_cache.size(1)
|
||||
latent_kv_dim = nope_cache.size(-1)
|
||||
max_num_blocks_per_seq = block_tables.size(1)
|
||||
batch_size = query_lens.size(0)
|
||||
nope_cache = nope_cache.squeeze()
|
||||
# select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe
|
||||
# cached_kv_c: [batch_size, max_context_len, latent_kv]
|
||||
# cached_k_pe: [batch_size, max_context_len, rope_dim]
|
||||
cache_kv_c = nope_cache[block_tables].view(
|
||||
batch_size, max_num_blocks_per_seq * block_size,
|
||||
latent_kv_dim)[:, :max_context_len, :]
|
||||
cache_k_pe = rope_cache[block_tables].view(
|
||||
batch_size, max_num_blocks_per_seq * block_size,
|
||||
rope_dim)[:, :max_context_len, :]
|
||||
# get k_rope and v
|
||||
# k_nope: [batch_size, max_context_len, num_heads, nope_dim]
|
||||
# value: [batch_size, max_context_len, num_heads, v_head_dim]
|
||||
k_nope, value = kv_b_proj(cache_kv_c)[0].view(
|
||||
batch_size, max_context_len, num_heads,
|
||||
nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
|
||||
# key: [batch_size, max_context_len, num_heads, rope_dim + nope_dim]
|
||||
key = torch.cat(
|
||||
[k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
|
||||
dim=-1)
|
||||
|
||||
context_lens = context_lens.view(-1, 1).to("npu")
|
||||
query_lens = query_lens.view(-1, 1).to("npu")
|
||||
seq_diff = context_lens - query_lens
|
||||
|
||||
q_idx_mask = (torch.arange(0, max_query_len,
|
||||
device="npu").view(1, -1).repeat(batch_size, 1))
|
||||
kv_c_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
kv_c_mask = kv_c_idx_mask < context_lens
|
||||
q_mask = q_idx_mask < query_lens
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(
|
||||
torch.ones(max_context_len, max_context_len, device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty(
|
||||
[batch_size, max_query_len, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[batch_size, max_query_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, v_head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
num_query = torch.sum(q_mask).item()
|
||||
num_add_query = num_query - query.size(0)
|
||||
# with MTP (multi-token prediction) the mask may cover more query slots than provided, so pad the query
|
||||
if num_add_query > 0:
|
||||
add_query_size = query.size()
|
||||
add_query_size = list(add_query_size)
|
||||
add_query_size[0] = num_add_query
|
||||
pad_tensor = torch.zeros(add_query_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
query = torch.cat([query, pad_tensor], dim=0)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[kv_c_mask] = key[kv_c_mask]
|
||||
pad_v[kv_c_mask] = value[kv_c_mask]
|
||||
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_context_len].masked_fill_(
|
||||
kv_c_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_heads,
|
||||
v_head_dim]).to(output.dtype))
|
||||
attn_output = attn_output.view_as(output)
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def vanilla_decode_mla(
|
||||
query: torch.Tensor, # [num_tokens, num_heads, latent_dim + rope_dim]
|
||||
key_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
|
||||
num_kv_heads: int,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
block_table: torch.Tensor, # [batch_size, max_block_size]
|
||||
context_lens: List[int],
|
||||
mla_vhead_size: int,
|
||||
rope_dim: int,
|
||||
output: torch.Tensor):
|
||||
batch_size = block_table.size()[0]
|
||||
max_block_size = block_table.size()[1]
|
||||
reduce_dim = key_cache.size()[-1]
|
||||
block_size = key_cache.size()[1]
|
||||
latent_dim = reduce_dim - rope_dim
|
||||
kv_c_and_pe = key_cache[block_table].view(
|
||||
[batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
|
||||
# [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim]
|
||||
# since DeepSeek uses a single KV head, expand (no copy) is used here for performance
|
||||
kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
|
||||
-1, -1, num_heads, 1)
|
||||
kv_c = kv_c_and_pe[..., :latent_dim]
|
||||
kv_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
# [batch_size, max_context_len]
|
||||
kv_idx_mask = kv_idx_mask < context_lens
|
||||
query = query.unsqueeze(1)
|
||||
attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
|
||||
attn_weights *= scale
|
||||
attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
|
||||
kv_c.float()).view(-1, num_heads, latent_dim)
|
||||
output.copy_(attn_output)
|
||||
return output
|
||||
539
vllm_npu/ops/casual_conv1d.py
Normal file
@@ -0,0 +1,539 @@
|
||||
# Adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
||||
# mypy: ignore-errors
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
conv_states: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
||||
sequences are concatenated from left to right for varlen
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended by 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
indicates the corresponding state index,
|
||||
like so: conv_state = conv_states[cache_indices[batch_id]]
|
||||
has_initial_state: (batch) bool
|
||||
indicates whether should the kernel take the current state as initial
|
||||
state for the calculations
|
||||
conv_states: (...,dim,width - 1) itype
|
||||
updated inplace if provided
|
||||
activation: either None or "silu" or "swish"
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
seqlens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
seqlens = seqlens.tolist()
|
||||
splits = torch.split(x, seqlens, dim=-1)
|
||||
|
||||
for i in range(len(seqlens)):
|
||||
x_s = splits[i]
|
||||
if cache_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
|
||||
initial_states=conv_states[cache_indices[i]]
|
||||
if has_initial_state[i] else None))
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
return out_ref_tensor
|
||||
|
||||
|
||||
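# Illustrative sketch (not part of this commit): driving causal_conv1d_fn with
# varlen input, following the query_start_loc example in the docstring above.
# All sizes here are assumptions chosen only for demonstration.
def _demo_causal_conv1d_varlen():
    dim, width = 8, 4
    # three sequences of lengths 10, 6 and 1, concatenated along the last dim
    query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32)
    x = torch.randn(dim, 17)
    weight = torch.randn(dim, width)
    conv_states = torch.zeros(3, dim, width - 1)  # one state per sequence
    out = causal_conv1d_fn(
        x,
        weight,
        bias=None,
        query_start_loc=query_start_loc,
        cache_indices=torch.tensor([0, 1, 2], dtype=torch.int32),
        has_initial_state=torch.tensor([False, False, False]),
        conv_states=conv_states,
        activation="silu",
    )
    return out  # (dim, 17): per-sequence causal-conv outputs, re-concatenated

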
@triton.jit()
|
||||
def _causal_conv1d_update_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr, # (batch, dim, seqlen)
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
conv_state_ptr,
|
||||
cache_seqlens_ptr, # circular buffer
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
intermediate_conv_window_ptr,
|
||||
o_ptr, # (batch, dim, seqlen)
|
||||
# Matrix dimensions
|
||||
batch: int,
|
||||
dim: tl.constexpr,
|
||||
seqlen: tl.constexpr,
|
||||
state_len: tl.constexpr,
|
||||
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||
# Strides
|
||||
stride_x_seq: tl.constexpr,
|
||||
stride_x_dim: tl.constexpr,
|
||||
stride_x_token: tl.constexpr,
|
||||
stride_w_dim: tl.constexpr,
|
||||
stride_w_width: tl.constexpr,
|
||||
stride_conv_state_seq: tl.constexpr,
|
||||
stride_conv_state_dim: tl.constexpr,
|
||||
stride_conv_state_tok: tl.constexpr,
|
||||
stride_state_indices: tl.constexpr,
|
||||
stride_inter_seq: tl.constexpr,
|
||||
stride_inter_step: tl.constexpr,
|
||||
stride_inter_dim: tl.constexpr,
|
||||
stride_inter_win: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SAVE_INTERMEDIATE: tl.constexpr,
|
||||
):
|
||||
# ruff: noqa: E501
|
||||
idx_seq = tl.program_id(0)
|
||||
if idx_seq >= batch:
|
||||
return
|
||||
|
||||
# [BLOCK_N,] elements along the feature-dimension (channel)
|
||||
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# mask = idx_seq < batch
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
conv_state_batch_coord = idx_seq
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_state_batch_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
|
||||
if IS_SPEC_DECODING:
|
||||
# The rolling of conv state:
|
||||
#
|
||||
# Before forward, the conv_state is:
|
||||
# [history1, history2, ..., historyM].
|
||||
#
|
||||
# After forward, the conv_state becomes:
|
||||
# [history2, ..., historyM, draft1, draft2, ..., draftN].
|
||||
#
|
||||
# After acceptance, it becomes:
|
||||
#
|
||||
# - accept 1 tokens: [history2, ..., historyM, draft1]
|
||||
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
|
||||
# - and so on.
|
||||
conv_state_token_offset = tl.load(num_accepted_tokens_ptr +
|
||||
idx_seq) - 1
|
||||
else:
|
||||
conv_state_token_offset = 0
|
||||
|
||||
# STEP 1: READ init_state data
|
||||
conv_states_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim))
|
||||
mask_w = idx_feats < dim
|
||||
|
||||
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
|
||||
if KERNEL_WIDTH >= 2:
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
|
||||
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
if KERNEL_WIDTH == 5:
|
||||
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
|
||||
#col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||
|
||||
# STEP 2: assume state_len > seqlen
|
||||
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# The conv_state updates works in a sliding window manner,
|
||||
# at each forward pass, the tokens are shift by 1, so we
|
||||
# load since idx_tokens + 1.
|
||||
conv_state_ptrs_source = (
|
||||
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
|
||||
|
||||
VAL = state_len - seqlen
|
||||
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
|
||||
) # [BLOCK_N]
|
||||
|
||||
x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
|
||||
mask_x = ((idx_tokens - VAL >= 0)[:, None]
|
||||
& (idx_tokens - VAL < seqlen)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
tl.debug_barrier()
|
||||
|
||||
new_conv_state = tl.where(mask, conv_state, loaded_x)
|
||||
|
||||
conv_state_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||
conv_state_ptrs_target = (conv_state_base +
|
||||
(idx_tokens * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
tl.store(conv_state_ptrs_target, new_conv_state, mask)
|
||||
|
||||
# STEP 3: init accumulator
|
||||
if HAS_BIAS:
|
||||
bias = bias_ptr + idx_feats
|
||||
mask_bias = idx_feats < dim
|
||||
acc_preload = tl.load(bias, mask=mask_bias,
|
||||
other=0.0).to(tl.float32) # [BLOCK_N]
|
||||
else:
|
||||
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
|
||||
# STEP 4:
|
||||
# PRE-LOAD WEIGHTS
|
||||
# first kernel column, configured for weights to handle BLOCK_N features in range
|
||||
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||
mask_w = idx_feats < dim
|
||||
if KERNEL_WIDTH >= 2:
|
||||
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
|
||||
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||
|
||||
x_base_1d = x_base # starting of chunk [BLOCK_N]
|
||||
mask_x_1d = idx_feats < dim
|
||||
|
||||
# STEP 5: compute each token
|
||||
for idx_token in tl.static_range(seqlen):
|
||||
acc = acc_preload
|
||||
|
||||
matrix_w = w_col0
|
||||
matrix_x = col0
|
||||
for j in tl.static_range(KERNEL_WIDTH):
|
||||
if KERNEL_WIDTH == 2:
|
||||
if j == 1: # KERNEL_WIDTH-1:
|
||||
matrix_w = w_col1
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
elif KERNEL_WIDTH == 3:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
matrix_x = col1
|
||||
elif j == 2:
|
||||
matrix_w = w_col2
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
elif KERNEL_WIDTH == 4:
|
||||
if j == 1:
|
||||
matrix_w = w_col1
|
||||
matrix_x = col1
|
||||
elif j == 2:
|
||||
matrix_w = w_col2
|
||||
matrix_x = col2
|
||||
elif j == 3:
|
||||
matrix_w = w_col3
|
||||
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||
|
||||
acc += matrix_x * matrix_w # [BLOCK_N]
|
||||
|
||||
if KERNEL_WIDTH == 2:
|
||||
col0 = matrix_x
|
||||
elif KERNEL_WIDTH == 3:
|
||||
col0 = col1
|
||||
col1 = matrix_x
|
||||
elif KERNEL_WIDTH == 4:
|
||||
col0 = col1
|
||||
col1 = col2
|
||||
col2 = matrix_x
|
||||
|
||||
if SILU_ACTIVATION:
|
||||
acc = acc / (1 + tl.exp(-acc))
|
||||
# mask_1d = (idx_token < seqlen) & (
|
||||
# idx_feats < dim
|
||||
# ) # token-index # feature-index
|
||||
maskL = idx_feats < dim
|
||||
maskR = tl.full(maskL.shape, False, tl.int1)
|
||||
mask_1d = tl.where(idx_token < seqlen, maskL, maskR)
|
||||
|
||||
o_ptrs = (o_ptr + (idx_seq) * stride_o_seq +
|
||||
idx_token * stride_o_token + (idx_feats * stride_o_dim))
|
||||
|
||||
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||
|
||||
if SAVE_INTERMEDIATE:
|
||||
# Save the window state after consuming this token
|
||||
# Layout: [seq(cache line), step, dim, win(K-1)]
|
||||
base_ptr = (intermediate_conv_window_ptr +
|
||||
conv_state_batch_coord * stride_inter_seq +
|
||||
idx_token * stride_inter_step +
|
||||
idx_feats * stride_inter_dim)
|
||||
if KERNEL_WIDTH >= 2:
|
||||
tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
|
||||
if KERNEL_WIDTH >= 3:
|
||||
tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
|
||||
if KERNEL_WIDTH >= 4:
|
||||
tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
|
||||
|
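# Illustrative sketch (not part of this commit): the sliding-window conv_state
# update that the kernel above performs, expressed in plain PyTorch for a
# single cache line. Tensor names are assumptions for demonstration only.
def _roll_conv_state(conv_state: torch.Tensor,
                     new_tokens: torch.Tensor) -> torch.Tensor:
    # conv_state: (dim, state_len), new_tokens: (dim, seqlen)
    state_len = conv_state.shape[-1]
    # shift the history left by seqlen and append the freshly consumed tokens
    return torch.cat([conv_state, new_tokens], dim=-1)[..., -state_len:]

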
||||
|
||||
def causal_conv1d_update_npu(
|
||||
x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Union[bool, str, None] = None,
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
intermediate_conv_window: Optional[torch.Tensor] = None,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim) or (batch, dim, seqlen)
|
||||
[shape=2: single token prediction]
|
||||
[shape=3: single or multiple tokens prediction]
|
||||
conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the conv_state
|
||||
starting at the index
|
||||
@cache_seqlens % state_len.
|
||||
conv_state_indices: (batch,), dtype int32
|
||||
If not None, the conv_state is a larger tensor along the batch dim,
|
||||
and we are selecting the batch coords specified by conv_state_indices.
|
||||
Useful for a continuous batching scenario.
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim) or (batch, dim, seqlen)
|
||||
"""
|
||||
if validate_data:
|
||||
assert cache_seqlens is None # not implemented yet - ok for vLLM
|
||||
assert pad_slot_id is not None
|
||||
assert x.stride(1) == 1
|
||||
if isinstance(activation, bool):
|
||||
activation = "silu" if activation is True else None
|
||||
elif activation is not None:
|
||||
assert activation in ["silu", "swish"]
|
||||
unsqueeze = x.dim() == 2
|
||||
if unsqueeze:
|
||||
# make it (batch, dim, seqlen) with seqlen == 1
|
||||
x = x.unsqueeze(-1)
|
||||
batch, dim, seqlen = x.shape
|
||||
_, width = weight.shape
|
||||
# conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
num_cache_lines, _, state_len = conv_state.size()
|
||||
|
||||
if validate_data:
|
||||
assert dim == weight.size(0)
|
||||
assert (
|
||||
conv_state.stride(-2) == 1
|
||||
), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
|
||||
assert state_len >= width - 1
|
||||
# when above happens, we don't shift-left to keep any records in conv_state
|
||||
assert dim == conv_state.size(1)
|
||||
if conv_state_indices is None:
|
||||
assert conv_state.size(0) >= batch
|
||||
else:
|
||||
assert (batch, ) == conv_state_indices.shape
|
||||
|
||||
assert num_cache_lines >= batch
|
||||
assert weight.stride(1) == 1 # Need this
|
||||
assert cache_seqlens is None # not needed for vLLM - circular buffer
|
||||
|
||||
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
|
||||
out = x
|
||||
stride_w_dim, stride_w_width = weight.stride()
|
||||
|
||||
stride_x_seq, stride_x_dim, stride_x_token = x.stride(
|
||||
) # X (batch, dim, seqlen)
|
||||
|
||||
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
||||
)
|
||||
stride_state_indices = (conv_state_indices.stride(0)
|
||||
if conv_state_indices is not None else 0)
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
batch,
|
||||
triton.cdiv(dim, META["BLOCK_N"]),
|
||||
)
|
||||
|
||||
# prepare intermediate buffer strides if provided
|
||||
if intermediate_conv_window is not None:
|
||||
stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
|
||||
intermediate_conv_window.stride(0),
|
||||
intermediate_conv_window.stride(1),
|
||||
intermediate_conv_window.stride(2),
|
||||
intermediate_conv_window.stride(3),
|
||||
)
|
||||
else:
|
||||
stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0
|
||||
|
||||
_causal_conv1d_update_kernel[grid](
|
||||
# Pointers to matrices
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
conv_state,
|
||||
cache_seqlens,
|
||||
conv_state_indices,
|
||||
num_accepted_tokens,
|
||||
intermediate_conv_window
|
||||
if intermediate_conv_window is not None else x,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
dim,
|
||||
seqlen,
|
||||
state_len,
|
||||
num_cache_lines,
|
||||
# stride
|
||||
stride_x_seq,
|
||||
stride_x_dim,
|
||||
stride_x_token,
|
||||
stride_w_dim,
|
||||
stride_w_width,
|
||||
stride_istate_seq,
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_state_indices,
|
||||
stride_inter_seq,
|
||||
stride_inter_step,
|
||||
stride_inter_dim,
|
||||
stride_inter_win,
|
||||
stride_o_seq,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
# others
|
||||
pad_slot_id,
|
||||
# META
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
|
||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
BLOCK_N=128,
|
||||
SAVE_INTERMEDIATE=intermediate_conv_window is not None,
|
||||
)
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return out
|
||||
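# Illustrative sketch (not part of this commit): a single decode step through
# causal_conv1d_update_npu with continuous batching. Sizes and the `.npu()`
# placement are assumptions for demonstration only.
def _demo_causal_conv1d_update():
    batch, dim, width, num_cache_lines = 4, 64, 4, 16
    x = torch.randn(batch, dim).npu()  # one new token per sequence
    # allocate the state so that the feature dimension is contiguous,
    # matching what the kernel expects for conv_state
    conv_state = torch.zeros(num_cache_lines, width - 1, dim).npu().transpose(1, 2)
    weight = torch.randn(dim, width).npu()
    conv_state_indices = torch.arange(batch, dtype=torch.int32).npu()
    out = causal_conv1d_update_npu(x,
                                   conv_state,
                                   weight,
                                   activation="silu",
                                   conv_state_indices=conv_state_indices)
    return out  # (batch, dim); conv_state is updated in place

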
451
vllm_npu/ops/common_fused_moe.py
Normal file
451
vllm_npu/ops/common_fused_moe.py
Normal file
@@ -0,0 +1,451 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os.path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map,
|
||||
get_compressed_expert_map)
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.ascend_forward_context import MoECommType
|
||||
from vllm_npu.distributed.parallel_state import get_mc2_group
|
||||
from vllm_npu.eplb.core.eplb_utils import determine_default_log2phy_map
|
||||
from vllm_npu.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_npu.ops.moe.experts_selector import select_experts
|
||||
from vllm_npu.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_npu.quantization.w8a8_dynamic import \
|
||||
AscendW8A8DynamicFusedMoEMethod
|
||||
from vllm_npu.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
|
||||
is_enable_nz, npu_stream_switch,
|
||||
shared_expert_dp_enabled,
|
||||
shared_experts_compute_stream)
|
||||
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig = None):
|
||||
|
||||
super().__init__(moe=moe)
|
||||
self.dynamic_eplb = get_ascend_config().dynamic_eplb
|
||||
self.transpose = True
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod,
|
||||
self).process_weights_after_loading(layer)
|
||||
if self.transpose:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||
requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
self.transpose = False
|
||||
else:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||
requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype):
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
enable_force_load_balance: bool = False,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
shared_experts=shared_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
moe_counter = -1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
num_experts = kwargs["num_experts"]
|
||||
intermediate_size = kwargs["intermediate_size"]
|
||||
|
||||
AscendFusedMoE.moe_counter += 1
|
||||
self.moe_instance_id = AscendFusedMoE.moe_counter
|
||||
|
||||
self.expert_map = None
|
||||
self.log2phy = None
|
||||
|
||||
if self.quant_config is None:
|
||||
self.quant_method = AscendUnquantizedFusedMoEMethod(
|
||||
self.moe_config)
|
||||
else:
|
||||
self.quant_method = self.quant_config.get_quant_method(
|
||||
self, self.layer_name)
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.moe_config.tp_group = get_tp_group()
|
||||
self.moe_config.dp_group = get_dp_group()
|
||||
self.moe_config.ep_group = get_ep_group()
|
||||
self.moe_config.mc2_group = get_mc2_group()
|
||||
ascend_config = get_ascend_config()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
self.expert_map_path = ascend_config.expert_map_path
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||
# TODO: Flag for static expert placement. This is a temporary workaround
|
||||
# to allow dynamic EPLB with float weights by skipping quantization checks.
|
||||
self.static_eplb_enabled = False
|
||||
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
|
||||
dtype=vllm_config.model_config.dtype)
|
||||
# static eplb initializing with expert_map_path
|
||||
init_eplb_enable = False
|
||||
if self.expert_map_path and os.path.exists(
|
||||
self.expert_map_path) and os.access(self.expert_map_path,
|
||||
os.R_OK):
|
||||
self.expert_load_balancer = ExpertLoadBalancer(
|
||||
self.expert_map_path, num_experts)
|
||||
self.expert_load_balancer.check_expert_map_tensor()
|
||||
self.global_redundant_expert_num = (
|
||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||
try:
|
||||
self.local_num_experts, self.expert_map = (
|
||||
self.expert_load_balancer.get_rank_placement_map(
|
||||
self.moe_instance_id, self.ep_rank))
|
||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
||||
self.moe_instance_id, self.ep_rank).npu()
|
||||
init_eplb_enable = True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Init expert map of mtp/eagle when using sample.{e}")
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||
self.log2phy = determine_default_log2phy_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank).npu()
|
||||
else:
|
||||
# init moe.
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||
# dynamic eplb initializing with not expert_map_path
|
||||
if self.dynamic_eplb:
|
||||
self.log2phy = determine_default_log2phy_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank).npu()
|
||||
if self.expert_map is not None and isinstance(self.expert_map,
|
||||
torch.Tensor):
|
||||
logger.info_once(
|
||||
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
|
||||
" number of experts: %s/%s. Experts local to global index map:"
|
||||
" %s.", self.ep_rank, self.ep_size, self.local_num_experts,
|
||||
self.global_num_experts,
|
||||
get_compressed_expert_map(self.expert_map))
|
||||
local_num_experts = (torch.sum(
|
||||
self.expert_map != -1) if self.expert_map is not None else
|
||||
self.global_num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts,
|
||||
dtype=torch.int64).npu()
|
||||
|
||||
if init_eplb_enable and (
|
||||
not hasattr(self.quant_method, "quant_method")
|
||||
or not isinstance(self.quant_method.quant_method,
|
||||
AscendW8A8DynamicFusedMoEMethod)):
|
||||
raise ValueError("Eplb supports only w8a8_dynamic quantization.")
|
||||
|
||||
self.moe_config.num_experts = self.global_num_experts
|
||||
self.moe_config.num_local_experts = self.local_num_experts
|
||||
self.moe_config.original_num_experts = num_experts
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": local_num_experts,
|
||||
"hidden_size": self.hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": self.params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
setup_moe_comm_method(self.moe_config)
|
||||
|
||||
def update_expert_map(self, new_expert_map):
|
||||
self.expert_map = new_expert_map
|
||||
|
||||
def get_map(self):
|
||||
return self.expert_map
|
||||
|
||||
def get_log2phy_map(self):
|
||||
return self.log2phy
|
||||
|
||||
def clear_moe_load(self):
|
||||
if self.moe_load is not None:
|
||||
self.moe_load.zero_()
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(
|
||||
self, final_hidden_states: torch.Tensor):
|
||||
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
|
||||
and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
|
||||
the outputs are already aggregated across tensor parallel ranks in the
|
||||
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
|
||||
outputs since each rank only has partial outputs.
|
||||
"""
|
||||
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
||||
quantized_x_for_share, dynamic_scale_for_share = None, None
|
||||
|
||||
forward_context = get_forward_context()
|
||||
|
||||
# Load balancing for token distribution among experts in dummy_run
|
||||
# TODO: The community only considers load balancing when DP > 1.
|
||||
# This approach may overlook some extreme scenarios.
|
||||
enable_force_load_balance = forward_context.in_profile_run
|
||||
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.sp_enabled,
|
||||
enable_shared_expert_dp=self.enable_shared_expert_dp)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
shared_experts=None,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
log2phy=self.log2phy,
|
||||
global_redundant_expert_num=self.global_redundant_expert_num)
|
||||
|
||||
if isinstance(final_hidden_states, tuple):
|
||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load += expert_tokens if group_list_type == 1 else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=final_hidden_states,
|
||||
reduce_results=self.reduce_results)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def transpose_weight(self, loaded_weight, expert_data, shard_dim):
|
||||
# Ensure training and inference weight shapes match during RL weight updates
|
||||
if (
|
||||
loaded_weight.shape[1] != expert_data.shape[1] and \
|
||||
loaded_weight.shape[0] != expert_data.shape[0]
|
||||
):
|
||||
shard_dim = int(not shard_dim)
|
||||
loaded_weight = loaded_weight.transpose(0, 1).contiguous()
|
||||
return loaded_weight, shard_dim
|
||||
|
||||
def _load_w13(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
loaded_weight, shard_dim = self.transpose_weight(
|
||||
loaded_weight, expert_data, shard_dim)
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim,
|
||||
shard_size * tp_rank,
|
||||
shard_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
loaded_weight, shard_dim = self.transpose_weight(
|
||||
loaded_weight, expert_data, shard_dim)
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim,
|
||||
shard_size * tp_rank,
|
||||
shard_size)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: torch.nn.Module,
|
||||
use_overlapped: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
AscendFusedMoE.__init__(self, **kwargs)
|
||||
self._shared_experts = shared_experts
|
||||
self.use_overlapped = use_overlapped
|
||||
self.shared_expert_stream = None
|
||||
ascend_config = get_ascend_config()
|
||||
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
|
||||
if enable_sp():
|
||||
logger.info_once(
|
||||
"Sequence parallelism is enabled, shared experts are replicated for best performance."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
shared_out, fused_out = AscendFusedMoE.forward(
|
||||
self,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
return shared_out, fused_out
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
# Make sure the shared experts stream begins after hidden_states are ready.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
shared_experts_compute_stream().wait_stream( # type: ignore
|
||||
torch.npu.current_stream())
|
||||
with npu_stream_switch(shared_experts_compute_stream(),
|
||||
enabled=self.multistream_overlap_shared_expert):
|
||||
# Use a separate stream to run shared experts.
|
||||
# Note that currently we only support calculations in separate streams with aclgraph.
|
||||
# Communication operations in another stream might cause unknown errors.
|
||||
shared_out = self._shared_experts(hidden_states)
|
||||
|
||||
fused_output = AscendFusedMoE.forward_impl(
|
||||
self,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
# Make sure the default stream waits for the shared experts stream to finish.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
torch.npu.current_stream().wait_stream(
|
||||
shared_experts_compute_stream())
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
|
||||
and not shared_expert_dp_enabled():
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
return shared_out, fused_output
|
||||
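# Illustrative sketch (not part of this commit): the cumulative-to-per-expert
# conversion used when accumulating `moe_load` in AscendFusedMoE.forward_impl
# above (a group_list_type other than 1 is treated as cumulative token counts).
def _per_expert_counts(expert_tokens: torch.Tensor) -> torch.Tensor:
    # e.g. cumulative [3, 3, 7, 10] -> per-expert [3, 0, 4, 3]
    return torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])

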
119
vllm_npu/ops/expert_load_balancer.py
Normal file
119
vllm_npu/ops/expert_load_balancer.py
Normal file
@@ -0,0 +1,119 @@
import json
import random
from typing import Dict, List

import torch
import torch.distributed as dist


class ExpertLoadBalancer(object):

    def __init__(self, expert_map_path, num_experts):
        self.expert_map_path = expert_map_path
        self.num_experts = num_experts
        self.tensor_data = []
        self.expert_map_tensor, self.layers_num, self.ranks_num = (
            self._expert_file_to_tensor())
        self.global_expert_num = (num_experts +
                                  self.get_global_redundant_expert_num())
        self.expert_placement_map = self.generate_expert_placement_map()

    def _expert_file_to_tensor(self):
        with open(self.expert_map_path, "r") as f:
            data = json.load(f)
        layers_num = data["moe_layer_count"]
        gpus_num = data["layer_list"][0]["device_count"]

        for layer in data["layer_list"]:
            device_data = []
            for device in layer["device_list"]:
                device_data.append(device["device_expert"])
            self.tensor_data.append(device_data)
        expert_map_tensor = torch.tensor(self.tensor_data, dtype=torch.int32)
        return expert_map_tensor, layers_num, gpus_num

    def generate_index_dicts(self, tensor_2d):
        dict_list = []
        current_idx = 0

        for row in tensor_2d:
            value_to_index = {}
            for i in range(row.size(0)):
                value = row[i].item()
                value_to_index[value] = current_idx + i
            dict_list.append(value_to_index)
            current_idx += row.size(0)

        return dict_list

    def generate_expert_placement_map(self):
        expert_placement_map = torch.full(
            (self.layers_num, self.ranks_num, self.num_experts),
            -1,
            dtype=torch.int32,
        )
        for layer_id in range(self.layers_num):
            for gpu_id in range(self.ranks_num):
                e_ids = self.expert_map_tensor[layer_id, gpu_id]
                expert_placement_map[layer_id, gpu_id,
                                     e_ids] = torch.arange(len(e_ids),
                                                           dtype=torch.int32)
        return expert_placement_map

    def generate_log2phy_expert_map(self, layer_id):
        concatenated = torch.flatten(self.expert_map_tensor[layer_id])
        rank_expert_to_global = self.generate_index_dicts(
            self.expert_map_tensor[layer_id])
        result_dict: Dict[int, List[int]] = {}
        for idx, value in enumerate(concatenated):
            key = value.item()
            if key not in result_dict:
                result_dict[key] = []
            result_dict[key].append(idx)

        log2phy_map = torch.full((self.ranks_num, self.num_experts),
                                 -1,
                                 dtype=torch.int32)
        for rank in range(self.ranks_num):
            for key in result_dict:
                indices_in_concat = result_dict[key]
                if key in rank_expert_to_global[rank]:
                    log2phy_map[rank][key] = rank_expert_to_global[rank][key]
                else:
                    chosen_index = random.choice(indices_in_concat)
                    log2phy_map[rank][key] = chosen_index
        return log2phy_map

    def get_rank_placement_map(self, layer_id, rank_id):
        layer_expert_map = self.expert_placement_map[layer_id]
        rank_expert_map = layer_expert_map[rank_id].to(
            torch.npu.current_device())
        rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
        return rank_local_expert_num, rank_expert_map

    def get_rank_log2phy_map(self, layer_id, rank_id):
        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
        return layer_log2phy_map[rank_id]

    def get_global_redundant_expert_num(self):
        global_redundant_expert_num = (
            len(self.expert_map_tensor[0][0]) * self.ranks_num -
            self.num_experts)
        return global_redundant_expert_num

    def check_expert_map_tensor(self):
        if dist.is_initialized():
            try:
                rank = dist.get_rank()
                world_size = dist.get_world_size()
                all_expert_maps = [None for _ in range(world_size)]
                dist.all_gather_object(all_expert_maps, self.tensor_data)
                for rank_id, expert_map_tensor in enumerate(all_expert_maps):
                    if self.tensor_data != expert_map_tensor:
                        raise ValueError(
                            f"The expert map of rank {rank} is not equal to rank {rank_id}"
                        )
                return True
            except Exception as e:
                raise ValueError(
                    f"The expert maps of all ranks are inconsistent: {e}")
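# Illustrative sketch (not part of this commit): the expert-map JSON layout
# that ExpertLoadBalancer._expert_file_to_tensor above parses. Only the key
# names come from the code; all values are made up for demonstration.
_EXAMPLE_EXPERT_MAP = {
    "moe_layer_count": 1,
    "layer_list": [
        {
            "device_count": 2,
            "device_list": [
                {"device_expert": [0, 1, 2, 7]},  # physical experts on rank 0
                {"device_expert": [3, 4, 5, 6]},  # physical experts on rank 1
            ],
        },
    ],
}
# With num_experts = 7, get_global_redundant_expert_num() would report
# 4 * 2 - 7 = 1 redundant physical expert.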
299
vllm_npu/ops/fla.py
Normal file
299
vllm_npu/ops/fla.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
# mypy: ignore-errors
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
MAX_CORES = 65535
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
})
|
||||
@triton.jit
|
||||
def layer_norm_fwd_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_z_row,
|
||||
M, # number of rows in X_base
|
||||
N, # number of columns in X_base
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
N_CORES: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X_base and Y_base it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
|
||||
BLOCK_ROWS = M if M < N_CORES else N_CORES
|
||||
n_iters = M // BLOCK_ROWS
|
||||
remain = M % BLOCK_ROWS
|
||||
if row < remain:
|
||||
n_iters = n_iters + 1
|
||||
|
||||
for i in tl.range(n_iters):
|
||||
X_base = X + (i * BLOCK_ROWS *
|
||||
stride_x_row) + row * stride_x_row + group * N
|
||||
Y_base = Y + (i * BLOCK_ROWS *
|
||||
stride_y_row) + row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z_base = Z + (i * BLOCK_ROWS *
|
||||
stride_z_row) + row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean_base = Mean + (i * BLOCK_ROWS) + group * M
|
||||
Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
|
||||
W_base = W + group * N
|
||||
if HAS_BIAS:
|
||||
B_base = B + group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean_base + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd_base + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W_base + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B_base + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y_base + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm else None)
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M if M < MAX_CORES else MAX_CORES, ngroups)
|
||||
with torch.npu.device(x.device.index):
|
||||
layer_norm_fwd_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
N_CORES=MAX_CORES,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
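# Illustrative sketch (not part of this commit): the gating semantics described
# in LayerNormFn.forward's docstring, written as plain PyTorch for the RMS-norm
# case with a single group. This is a reference, not the fused Triton path.
def _gated_rms_norm_ref(x: torch.Tensor,
                        weight: torch.Tensor,
                        z: torch.Tensor,
                        eps: float = 1e-6,
                        norm_before_gate: bool = True) -> torch.Tensor:
    def rms_norm(t: torch.Tensor) -> torch.Tensor:
        return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps) * weight

    if norm_before_gate:
        return rms_norm(x) * F.silu(z)  # norm(x) * silu(z)
    return rms_norm(x * F.silu(z))      # norm(x * silu(z))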
|
||||
|
||||
def torch_chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g,
|
||||
beta,
|
||||
chunk_size=64,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
):
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = F.normalize(query, p=2, dim=-1)
|
||||
key = F.normalize(key, p=2, dim=-1)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (query, key, value, beta, g)
|
||||
]
|
||||
|
||||
batch_size, sequence_length, num_heads, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
tot_heads = num_heads + pad_size
|
||||
scale = 1 / (query.shape[-1]**0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
# reshape to chunks
|
||||
query, key, value, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
|
||||
for x in (query, key, value, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=0)
|
||||
|
||||
# chunk decay
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) -
|
||||
g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -(
|
||||
(k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
|
||||
last_recurrent_state = (torch.zeros(batch_size, sequence_length,
|
||||
k_head_dim, v_head_dim).to(value) if
|
||||
initial_state is None else initial_state.to(value))
|
||||
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=1)
|
||||
|
||||
# for each chunk
|
||||
for i in range(0, tot_heads // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) *
|
||||
decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp() +
|
||||
(k_i *
|
||||
(g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
|
||||
-1, -2) @ v_new)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
|
||||
core_attn_out.shape[1], -1,
|
||||
core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :num_heads]
|
||||
core_attn_out = core_attn_out.transpose(1,
|
||||
2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
@@ -1,36 +1,213 @@
|
||||
"""
|
||||
NPU-optimized layer normalization for Ascend.
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
|
||||
``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
|
||||
to NPU kernels automatically.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
|
||||
|
||||
def _addrmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
layer: Optional[torch.nn.Module] = None,
|
||||
bias: Optional[torch.nn.Parameter] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_npu.utils import is_310p
|
||||
|
||||
if layer is not None and not is_310p():
|
||||
layer_cls_name = layer.__class__.__name__
|
||||
try:
|
||||
weight_prefetch_method = get_forward_context(
|
||||
).weight_prefetch_method
|
||||
except AssertionError:
|
||||
weight_prefetch_method = None
|
||||
|
||||
# prefetch qkvo_proj.weight preprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||
layer_cls_name=layer_cls_name,
|
||||
weight=layer.weight,
|
||||
start_flag=x,
|
||||
)
|
||||
# add_rms_norm_quant
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
beta=bias,
|
||||
epsilon=self.variance_epsilon)
|
||||
|
||||
# prefetch qkvo_proj.weight postprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
||||
layer_cls_name=layer_cls_name,
|
||||
stop_flag=x,
|
||||
)
|
||||
|
||||
else:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
if bias is not None:
|
||||
x.add_(bias)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
"""RMSNorm using Ascend NPU fused kernels.
|
||||
|
||||
Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
|
||||
``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.bias = None
|
||||
# quantization with anti_method m4 will generate a non-zero norm bias
|
||||
if vllm_config.quant_config is not None and \
|
||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        import torch_npu

        if residual is not None:
            assert x.size(0) == residual.size(0)
            x, residual = _addrmsnorm_forward_oot(
                self, x, residual, self.next_need_quant_fusion_linear,
                self.bias)
            return x, residual
        x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                      self.variance_epsilon)
        if self.bias is not None:
            x.add_(self.bias)
        return x
|
||||
|
||||
@property
|
||||
def next_need_quant_fusion_linear(self):
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
if not forward_context.addrmsnorm_quant_fusion_enabled or \
|
||||
forward_context.layer_idx == forward_context.num_hidden_layers:
|
||||
return None
|
||||
except AssertionError:
|
||||
return None
|
||||
|
||||
next_linear = None
|
||||
model_instance = forward_context.model_instance
|
||||
layer_idx = forward_context.layer_idx
|
||||
fusion_linear = forward_context.fusion_linear
|
||||
if fusion_linear == "qkv_dense":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].self_attn.qkv_proj
|
||||
forward_context.fusion_linear = "gate_up_dense"
|
||||
elif fusion_linear == "gate_up_dense":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].mlp.gate_up_proj
|
||||
forward_context.fusion_linear = "qkv_dense"
|
||||
# if prefetch_mlp_weight enabled, following accumulation operation
|
||||
# does not need to be repeated
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
forward_context.layer_idx += 1
|
||||
elif fusion_linear == "qkv_moe":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].self_attn.qkv_proj
|
||||
forward_context.fusion_linear = "gate_moe"
|
||||
elif fusion_linear == "gate_moe":
|
||||
forward_context.fusion_linear = "qkv_moe"
|
||||
forward_context.layer_idx += 1
|
||||
from vllm_npu.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
if next_linear is not None and \
|
||||
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
||||
next_linear = None
|
||||
return next_linear
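
# Illustrative usage sketch (not part of this file): once the platform plugin
# routes vLLM's CustomOp dispatch to forward_oot, AscendRMSNorm is used like
# vLLM's RMSNorm. Shapes, dtype and device below are assumptions.
#
#   norm = AscendRMSNorm(hidden_size=4096, eps=1e-6)
#   x = torch.randn(8, 4096, dtype=torch.bfloat16, device="npu")
#   res = torch.randn_like(x)
#   y = norm(x)            # standalone path -> torch_npu.npu_rms_norm
#   y, res = norm(x, res)  # fused residual path -> npu_add_rms_norm (+ optional quant fusion)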
|
||||
|
||||
|
||||
class AscendQuantRMSNorm(AscendRMSNorm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
x, residual = super().forward_oot(x, residual)
|
||||
return x.add_(self.bias), residual
|
||||
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
|
||||
|
||||
|
||||
class AscendGemmaRMSNorm(GemmaRMSNorm):
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_npu.utils import is_310p
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, 1.0 + self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
                                      self.variance_epsilon)
|
||||
return x
|
||||
|
||||
466
vllm_npu/ops/linear.py
Normal file
@@ -0,0 +1,466 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
To customize linear communication groups or forward of classes in this file,
|
||||
extend new linear operations in linear_op.py.
|
||||
The classes in this file should not be modified, including AscendQKVParallelLinear,
|
||||
AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear,
|
||||
AscendRowParallelLinear and AscendColumnParallelLinear.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import divide
|
||||
from vllm.model_executor.layers.linear import ( # noqa
|
||||
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
|
||||
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
|
||||
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_npu.ops.linear_op import get_parallel_op, get_replicated_op
|
||||
from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
|
||||
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
||||
"""Linear method without quantization"""
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
if (is_enable_nz(layer.weight.data.dtype)):
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
|
||||
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
|
||||
class AscendLinearBase(LinearBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[
|
||||
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self,
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
self.disable_tp = disable_tp
|
||||
|
||||
|
||||
class AscendQKVParallelLinear(QKVParallelLinear):
|
||||
"""Linear layers for the attention's QKV transformation.
|
||||
|
||||
Linear layers for the linear transformation of the query, key, and value
|
||||
vectors in the attention layer. The weight matrix is concatenated along
|
||||
the output dimension. The layer is parallelized along the head dimension.
|
||||
When the number of key/value heads is smaller than the number of query
|
||||
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
||||
be replicated while the query heads are partitioned.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, _, tp_size = get_parallel_op(disable_tp, prefix, self,
|
||||
"column")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
self.num_kv_heads = 1
|
||||
self.num_kv_head_replicas = divide(tp_size,
|
||||
self.total_num_kv_heads)
|
||||
else:
|
||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||
self.num_kv_head_replicas = 1
|
||||
input_size = self.hidden_size
|
||||
output_size = (self.num_heads +
|
||||
2 * self.num_kv_heads) * tp_size * self.head_size
|
||||
self.output_sizes = [
|
||||
self.num_heads * self.head_size * tp_size, # q_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||
]
|
||||
AscendColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
along the output dimension. When the weight matrix is loaded, the
|
||||
different partitions are sharded separately.
|
||||
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
||||
disable_tp, prefix, self, "column")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.output_sizes = output_sizes
|
||||
assert all(output_size % self.tp_size == 0
|
||||
for output_size in output_sizes)
|
||||
AscendColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""Linear layer with row parallelism.
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
# TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete.
|
||||
# TODO(MengqingCao): Remove the empty string check, after specifying the prefix in linear layers of some models in the vLLM.
|
||||
if prefix in compilation_config.static_forward_context and \
|
||||
prefix != "" and \
|
||||
"visual" not in prefix:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
||||
disable_tp, prefix, self, "row")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
self.custom_op.update_attrs()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
is_prefill: bool = True,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
||||
disable_tp, prefix, self, "column")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
self.custom_op.update_attrs()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendReplicatedLinear(ReplicatedLinear):
|
||||
"""Ascend Replicated linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_size: output dimension of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: Take no effect for replicated linear layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op = get_replicated_op(disable_tp, prefix, self)
|
||||
# If MergedReplicatedLinear, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = self.output_sizes
|
||||
else:
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size, [self.output_size],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=self.params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
self.custom_op.update_attrs()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
531
vllm_npu/ops/linear_op.py
Normal file
@@ -0,0 +1,531 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file extends the functionality of linear operations by encapsulating custom
|
||||
communication groups and forward functions into classes (linear ops).
|
||||
|
||||
Current class inheritance structure:
|
||||
CustomLinearOp
|
||||
├── CustomColumnParallelOp
|
||||
│ ├── MLPColumnParallelOp
|
||||
│ ├── SequenceColumnParallelOp
|
||||
└── CustomRowParallelOp
|
||||
│ ├── MLPRowParallelOp
|
||||
│ ├── OProjRowParallelOp
|
||||
│ ├── MatmulAllreduceRowParallelOp
|
||||
│ └── SequenceRowParallelOp
|
||||
└── CustomReplicatedOp
|
||||
How to extend a new linear op? Taking column parallel op as an example:
|
||||
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
|
||||
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
|
||||
3. Override the apply method according to requirements, which will replace the original linear.forward
|
||||
4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on prefix and configuration judgments
|
||||
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_npu.distributed.parallel_state import (get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_npu.utils import (dense_optim_enable, enable_sp,
|
||||
matmul_allreduce_enable, mlp_tp_enable,
|
||||
oproj_tp_enable, shared_expert_dp_enabled)
|
||||
|
||||
|
||||
class CustomLinearOp:
|
||||
|
||||
def __init__(self, layer):
|
||||
self.layer = layer
|
||||
self.bias = None
|
||||
self.skip_bias_add = None
|
||||
self.return_bias = None
|
||||
self.quant_method = None
|
||||
|
||||
# Custom communication group, while determining weight sharding
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_tp_group()
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.comm_group.rank_in_group
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.comm_group.world_size
|
||||
|
||||
# Update the attributes required by apply(), obtaining them from the layer.
|
||||
# Call this after the layer completes its initialization, specifically at the end of layer.init().
|
||||
def update_attrs(self):
|
||||
if hasattr(self.layer, "bias"):
|
||||
self.bias = self.layer.bias
|
||||
self.skip_bias_add = self.layer.skip_bias_add
|
||||
self.return_bias = self.layer.return_bias
|
||||
self.quant_method = self.layer.quant_method
|
||||
self.prefix = self.layer.prefix
|
||||
|
||||
def apply_impl(self, input_):
|
||||
raise NotImplementedError
|
||||
|
||||
# Replace layer.forward to customize the layer computation process.
|
||||
def apply(self, input_):
|
||||
output, output_bias = self.apply_impl(input_)
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomColumnParallelOp(CustomLinearOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.gather_output = None
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.gather_output = self.layer.gather_output
|
||||
|
||||
|
||||
class CustomRowParallelOp(CustomLinearOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.reduce_results = None
|
||||
self.input_is_parallel = None
|
||||
self.input_size_per_partition = None
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
self.reduce_results = self.layer.reduce_results
|
||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||
|
||||
def apply(self, input_):
|
||||
output, output_bias = self.apply_impl(input_)
|
||||
if dense_optim_enable():
|
||||
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomReplicatedOp(CustomLinearOp):
|
||||
|
||||
def apply_impl(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
|
||||
output = self.quant_method.apply(self.layer, input_, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MLPColumnParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_mlp_tp_group()
|
||||
|
||||
def apply_impl(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
input_parallel = self.comm_group.all_gather(input_, 0)
|
||||
output = self.quant_method.apply(self.layer, input_parallel, bias)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MLPRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_mlp_tp_group()
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
assert self.quant_method is not None
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.layer.bias
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = self.comm_group.reduce_scatter(output_parallel, 0)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class OProjRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
|
||||
@property
|
||||
def comm_group(self):
|
||||
return get_otp_group()
|
||||
|
||||
def apply_impl(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Prepare tensors for all-to-all communication
|
||||
local_batch_size = input_parallel.size(0)
|
||||
chunk_size = self.input_size_per_partition
|
||||
total_batch_size = local_batch_size * self.tp_size
|
||||
|
||||
# Reshape tensor for efficient cross-device transfer:
|
||||
# [batch, dim] -> [tp_size, batch, chunk] -> flattened
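        # (Illustrative sizes only: tp_size=4, local batch=2, chunk=128 gives
        #  [2, 512] -> [4, 2, 128] -> a flat buffer of 1024 values.)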
|
||||
send_buf = (input_parallel.reshape(-1,
|
||||
self.tp_size, chunk_size).transpose(
|
||||
0, 1).contiguous().view(-1))
|
||||
|
||||
# Create receive buffer
|
||||
recv_buf = torch.empty(total_batch_size * chunk_size,
|
||||
dtype=input_parallel.dtype,
|
||||
device=input_parallel.device)
|
||||
|
||||
# Perform all-to-all communication
|
||||
dist.all_to_all_single(recv_buf,
|
||||
send_buf,
|
||||
group=self.comm_group.device_group)
|
||||
input_parallel = recv_buf.view(total_batch_size, chunk_size)
|
||||
|
||||
# Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
|
||||
# otp-specific: Combine partial results across devices
|
||||
output = self.comm_group.reduce_scatter(output_parallel, dim=0)
|
||||
output = output.view(input_.shape[0], self.layer.output_size)
|
||||
|
||||
# Handle bias return based on configuration
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
self.input_size_per_partition = self.layer.input_size_per_partition
|
||||
|
||||
|
||||
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
||||
_HCOMM_INFO = None
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation."""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
assert self.quant_method is not None
|
||||
output = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
@classmethod
|
||||
def get_hcomm_info(cls, group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group."""
|
||||
if cls._HCOMM_INFO is not None:
|
||||
return cls._HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
cls._HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
else:
|
||||
cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return cls._HCOMM_INFO
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.weight_t = self.layer.weight.t()
|
||||
|
||||
|
||||
class SequenceColumnParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = self.comm_group.all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
assert self.quant_method is not None
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
|
||||
if self.tp_size == 1 or not self.reduce_results:
|
||||
output = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
else:
|
||||
output = torch.ops.vllm.matmul_and_reduce(input_parallel,
|
||||
self.prefix)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
def matmul_and_reduce(self, input_parallel: torch.Tensor,
|
||||
bias_: Optional[Parameter]) -> torch.Tensor:
|
||||
assert self.quant_method is not None
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
mmrs_fusion = forward_context.mmrs_fusion
|
||||
except AssertionError:
|
||||
sp_enabled = False
|
||||
mmrs_fusion = False
|
||||
|
||||
x = input_parallel
|
||||
|
||||
if not sp_enabled:
|
||||
output_parallel = self.layer.quant_method.apply(self.layer,
|
||||
x,
|
||||
bias=bias_)
|
||||
return tensor_model_parallel_all_reduce(output_parallel)
|
||||
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_size))
|
||||
|
||||
world_size = self.layer.tp_size
|
||||
comm_mode = "aiv"
|
||||
hcom_name = get_tp_group().device_group._get_backend(
|
||||
torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
|
||||
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
|
||||
from vllm_npu.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_npu.quantization.w8a8 import (AscendW8A8LinearMethod,
|
||||
quant_per_tensor)
|
||||
|
||||
# For unquant
|
||||
if mmrs_fusion and isinstance(self.layer.quant_method,
|
||||
UnquantizedLinearMethod):
|
||||
output = torch_npu.npu_mm_reduce_scatter_base(
|
||||
x,
|
||||
self.layer.weight.t(),
|
||||
hcom_name,
|
||||
world_size,
|
||||
reduce_op="sum",
|
||||
bias=None,
|
||||
comm_turn=0,
|
||||
comm_mode=comm_mode)
|
||||
if bias_ is not None:
|
||||
output.add_(bias_)
|
||||
# For w8a8 quant
|
||||
elif mmrs_fusion and (
|
||||
isinstance(self.layer.quant_method, AscendLinearMethod)
|
||||
and isinstance(self.layer.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod)):
|
||||
if x.dtype != torch.int8:
|
||||
x_quant = quant_per_tensor(
|
||||
x, self.layer.aclnn_input_scale_reciprocal,
|
||||
self.layer.aclnn_input_offset)
|
||||
else:
|
||||
x_quant = x
|
||||
quant_bias = self.layer.quant_bias
|
||||
deq_scale = self.layer.deq_scale
|
||||
output_dtype = torch.bfloat16
|
||||
output = torch_npu.npu_mm_reduce_scatter_base(
|
||||
x_quant,
|
||||
self.layer.weight,
|
||||
hcom_name,
|
||||
world_size,
|
||||
reduce_op="sum",
|
||||
bias=None,
|
||||
comm_turn=0,
|
||||
x2_scale=deq_scale,
|
||||
output_dtype=output_dtype,
|
||||
comm_mode=comm_mode)
|
||||
output = torch.add(
|
||||
output,
|
||||
torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
|
||||
else:
|
||||
output_parallel = self.layer.quant_method.apply(self.layer,
|
||||
x,
|
||||
bias=bias_)
|
||||
output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
|
||||
|
||||
return output
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
self.reduce_results = self.layer.reduce_results
|
||||
|
||||
|
||||
def _get_column_parallel_op(
|
||||
prefix, layer
|
||||
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
|
||||
if mlp_tp_enable() and "gate_up_proj" in prefix:
|
||||
return MLPColumnParallelOp(layer)
|
||||
if enable_sp():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
if "gate_up_proj" in prefix:
|
||||
return SequenceColumnParallelOp(layer)
|
||||
if "in_proj" in prefix:
|
||||
return SequenceColumnParallelOp(layer)
|
||||
if "qkv_proj" in prefix or "conv1d" in prefix:
|
||||
return SequenceColumnParallelOp(layer)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_row_parallel_op(
|
||||
prefix, layer
|
||||
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
|
||||
if "down_proj" in prefix and mlp_tp_enable():
|
||||
return MLPRowParallelOp(layer)
|
||||
if "o_proj" in prefix and oproj_tp_enable():
|
||||
return OProjRowParallelOp(layer)
|
||||
if matmul_allreduce_enable():
|
||||
return MatmulAllreduceRowParallelOp(layer)
|
||||
if enable_sp():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
if "o_proj" in prefix or "out_proj" in prefix or "down_proj" in prefix:
|
||||
return SequenceRowParallelOp(layer)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_parallel_op(disable_tp, prefix, layer, direct):
|
||||
if disable_tp or ("shared_experts" in prefix
|
||||
and shared_expert_dp_enabled()):
|
||||
return None, 0, 1
|
||||
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
||||
MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp,
|
||||
SequenceRowParallelOp]] = None
|
||||
if direct == "row":
|
||||
custom_op = _get_row_parallel_op(prefix, layer)
|
||||
|
||||
if direct == "column":
|
||||
custom_op = _get_column_parallel_op(prefix, layer)
|
||||
|
||||
if custom_op is not None:
|
||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||
|
||||
return None, get_tp_group().rank_in_group, get_tp_group().world_size
|
||||
|
||||
|
||||
def get_replicated_op(disable_tp, prefix,
|
||||
layer) -> Optional[Union[CustomReplicatedOp]]:
|
||||
if disable_tp:
|
||||
return None
|
||||
|
||||
return CustomReplicatedOp(layer)
|
||||
0
vllm_npu/ops/moe/__init__.py
Normal file
113
vllm_npu/ops/moe/comm_utils.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
|
||||
COMM_STREAM = None
|
||||
|
||||
|
||||
def async_all_to_all(input_,
|
||||
output_split_sizes,
|
||||
input_split_sizes,
|
||||
group,
|
||||
event=None):
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
a2a_out = torch.empty_like(input_)
|
||||
else:
|
||||
# Unequal split (all2all-v)
|
||||
a2a_out = input_.new_empty(
|
||||
size=[sum(output_split_sizes)] + list(input_.size()[1:]),
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
if event:
|
||||
# multi stream wait event
|
||||
global COMM_STREAM
|
||||
if COMM_STREAM is None:
|
||||
COMM_STREAM = torch_npu.npu.Stream(
|
||||
device=torch.npu.current_device())
|
||||
with torch_npu.npu.stream(COMM_STREAM):
|
||||
event.wait()
|
||||
handle = dist.all_to_all_single(
|
||||
a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
else:
|
||||
handle = dist.all_to_all_single(a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
return input_, a2a_out, handle
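
# Illustrative usage sketch (not part of this file): overlap the all-to-all
# with independent work and wait on the returned handle before using the
# output. The names below are assumptions for demonstration only.
#
#   _, permuted, handle = async_all_to_all(tokens, out_splits, in_splits, ep_group)
#   ...  # independent compute here overlaps with the communication
#   handle.wait()
#   consume(permuted)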
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_, group, output_split_sizes=None):
|
||||
"""Gather tensors and concatenate along the first dimension.
|
||||
|
||||
Args:
|
||||
        input_ (torch.Tensor):
            A tensor to be gathered.
|
||||
output_split_sizes (List[int], optional):
|
||||
A list specifying the sizes of the output splits along the first dimension.
|
||||
If None, equal splitting is assumed. Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Gathered tensor.
|
||||
"""
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
if output_split_sizes is None:
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
else:
|
||||
dim_size[0] = sum(output_split_sizes)
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
output_tensor_list = list(
|
||||
torch.split(output, output_split_sizes, dim=0))
|
||||
torch.distributed.all_gather(output_tensor_list, input_, group=group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def gather_from_sequence_parallel_region(
|
||||
input_,
|
||||
group,
|
||||
output_split_sizes=None,
|
||||
):
|
||||
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
277
vllm_npu/ops/moe/experts_selector.py
Normal file
@@ -0,0 +1,277 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
global_num_experts: int = -1):
|
||||
"""
|
||||
    Select the top-k experts for each token based on the router logits.
|
||||
Args:
|
||||
router_logits: router logits of shape (num_tokens, hidden_size).
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
top_k: number of top k experts.
|
||||
use_grouped_topk: Whether to group experts before selecting top-k.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
topk_group: Number of expert groups to select from.
|
||||
num_expert_group: Number of experts in each group.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
indices_type: dtype of indices
|
||||
global_num_experts: Global number of experts.
|
||||
|
||||
Returns:
|
||||
topk_weights: router weights of shape (num_tokens, top_k).
|
||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||
"""
|
||||
# prefetch w1_w3_proj.weight preprocess
|
||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||
hidden_states, "gate_up")
|
||||
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_weights is None:
|
||||
topk_weights, topk_ids = _native_select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
return topk_weights, topk_ids
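
# Illustrative usage sketch (not part of this file): a typical call from a MoE
# layer; the tensor shapes and routing configuration are assumptions.
#
#   topk_weights, topk_ids = select_experts(
#       hidden_states=hidden_states,        # [num_tokens, hidden_size]
#       router_logits=router_logits,        # [num_tokens, num_experts]
#       top_k=8,
#       use_grouped_topk=True,
#       renormalize=True,
#       topk_group=4,
#       num_expert_group=8,
#       scoring_func="sigmoid",
#       e_score_correction_bias=bias,       # optional
#       global_num_experts=256)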
|
||||
|
||||
|
||||
def _native_grouped_topk(
|
||||
topk_weights: torch.Tensor,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
):
|
||||
topk_group = 0 if topk_group is None else topk_group
|
||||
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
||||
|
||||
num_token = topk_weights.shape[0]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
||||
k=topk_group,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
topk_group_mask = torch.zeros_like(grouped_weights)
|
||||
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
||||
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
||||
num_token, num_expert_group,
|
||||
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
||||
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
||||
|
||||
return topk_weights
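
# Worked example (illustrative): with 16 experts split into num_expert_group=4
# groups and topk_group=2, each group's score is the max over its 4 experts,
# the 2 strongest groups are kept, and the weights of all experts in the other
# groups are masked to 0 before the final per-token top-k.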
|
||||
|
||||
|
||||
def _renormalize_topk_weights(
|
||||
topk_weights: torch.Tensor,
|
||||
renormalize: bool,
|
||||
):
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights
|
||||
|
||||
|
||||
def _select_expert_use_group_topk(
|
||||
topk_weights: torch.Tensor, topk_group: Optional[int],
|
||||
renormalize: bool, top_k: int, num_expert_group: Optional[int],
|
||||
e_score_correction_bias: Optional[torch.Tensor]):
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_weights = topk_weights
|
||||
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def _select_experts_with_fusion_ops(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
e_score_correction_bias: Optional[torch.Tensor],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: Optional[int],
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1):
|
||||
|
||||
topk_weights, topk_ids = None, None
|
||||
    # NOTE: npu_moe_gating_top_k currently only supports the 'group_count=256' pattern
|
||||
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk currently 8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=
|
||||
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
x=router_logits, finished=None, k=top_k)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def _native_select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: Optional[torch.Tensor] = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
router_logits: Router logits of shape (num_tokens, num_experts).
|
||||
top_k: Number of experts to select.
|
||||
use_grouped_topk: Whether to group experts before selecting top-k.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
topk_group: Number of expert groups to select from.
|
||||
num_expert_group: Number of experts in each group.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
|
||||
Returns:
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_weights = router_logits.softmax(dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights = router_logits.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
if use_grouped_topk:
|
||||
return _select_expert_use_group_topk(
|
||||
topk_weights=topk_weights,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
if custom_routing_function is not None:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts)
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
520
vllm_npu/ops/moe/fused_moe_prepare_and_finalize.py
Normal file
@@ -0,0 +1,520 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_npu.utils import enable_sp
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
|
||||
in distributed environments. Subclasses implement specific communication strategies
|
||||
(e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing,
|
||||
broadcasting, and reduction across TP/DP/EP groups.
|
||||
|
||||
Attributes:
|
||||
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare tensors before MoE computation. May involve:
|
||||
- Padding to align communication boundaries
|
||||
- Slicing across tensor-parallel ranks
|
||||
- Broadcasting across data-parallel ranks
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- processed hidden_states (may be padded/sliced/broadcasted)
|
||||
- processed router_logits (may be recomputed or broadcasted)
|
||||
- optional communication mask (e.g., mc2_mask for sparse ops)
|
||||
"""
|
||||
raise NotImplementedError("Prepare not implemented.")
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalize MoE output. May involve:
|
||||
- Gathering sliced tensors across TP ranks
|
||||
- Reducing or scattering across DP ranks
|
||||
- Unpadding to original token count
|
||||
- Applying all-reduce across TP/EP if requested
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced
|
||||
reduce_results (bool): Whether to apply all-reduce across TP/EP groups
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Final output with shape [original_num_tokens, hidden_size]
|
||||
"""
|
||||
raise NotImplementedError("Finalize function not implemented.")
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using MC2 (Memory-Centric Communication).
|
||||
Designed for Ascend or environments requiring explicit padding and slicing control.
|
||||
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
"""
|
||||
Restore original TP configuration.
|
||||
vLLM flattens TP and DP into a single dimension; this method recovers
|
||||
the true TP world size and rank for correct tensor slicing.
|
||||
"""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||
2. Pad `hidden_states` and `router_logits` to target length if needed.
|
||||
3. If TP > 1, split tensors along token dimension and select current TP rank's slice.
|
||||
4. Split and return corresponding `mc2_mask`.
|
||||
|
||||
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
if self.tp_size > 1:
|
||||
# Also slice mc2_mask
|
||||
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
|
||||
mc2_mask = split_mc2_mask[self.tp_rank]
|
||||
|
||||
if not self.replace_allreduce:
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
pad_size = target_pad_length - self.num_tokens
|
||||
|
||||
# Pad if necessary (unless shared expert DP is enabled)
|
||||
if pad_size > 0 and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# Slice across TP ranks
|
||||
if self.tp_size > 1 and not self.enable_shared_expert_dp:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
self.split_hidden_states = split_hidden_states # Save for finalize
|
||||
|
||||
return hidden_states, router_logits, mc2_mask
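# Worked example of the pad-then-slice step above, assuming tp_size=4 and a batch of
# 10 tokens padded up to padded_num_tokens=12 (illustrative numbers only; the _demo_*
# helper is not part of this module).
def _demo_mc2_pad_and_slice():
    import torch
    import torch.nn.functional as F
    tp_size, tp_rank = 4, 1
    hidden_states = torch.randn(10, 16)                 # 10 tokens, hidden_size 16
    pad_size = 12 - hidden_states.shape[0]              # pad to padded_num_tokens=12
    hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    shards = torch.tensor_split(hidden_states, tp_size, dim=0)
    return shards[tp_rank]                              # each rank keeps a 3-token shard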
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
|
||||
2. Unpad to original token count if padding was applied.
|
||||
3. Return tensor with shape [original_num_tokens, hidden_size].
|
||||
|
||||
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
# All-gather across TP group
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
# TODO: Quick fix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to a memory explosion in eager mode.
|
||||
del self.split_hidden_states
|
||||
|
||||
# Unpad if necessary
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-to-All style slicing.
|
||||
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
|
||||
Used when num_tokens exceeds MC2's limit (512 tokens per rank).
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
"""Restore original TP configuration (same as MC2)."""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
2. If TP > 1, split along token dim and select current TP rank's slice.
|
||||
3. Save splits for later all-gather in finalize.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if not (self.replace_allreduce or self.enable_shared_expert_dp):
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
self.split_hidden_states = split_hidden_states
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices to reconstruct full tensor.
|
||||
2. Unpad to original token count.
|
||||
3. Return [original_num_tokens, hidden_size] tensor.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
# TODO: Quick fix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to a memory explosion in eager mode.
|
||||
del self.split_hidden_states
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
|
||||
There are two sets of prepare and finalize:
|
||||
1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
|
||||
we gather inputs across DP ranks before MoE, scatter outputs after.
|
||||
The communication and calculation process is as follows (AG, AR and RS
|
||||
are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):
|
||||
|
||||
Attn → TP AR → DP AG → MoE → DP RS → TP AR
|
||||
|
||||
2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
|
||||
the above process becomes:
|
||||
|
||||
TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS
|
||||
|
||||
This strategy further combines TP AG + DP AG into EP All-Gather and TP RS + DP RS
|
||||
into EP Reduce-Scatter to improve communication performance. The optimized process is as follows:
|
||||
|
||||
TP AG → Attn → TP RS → EP AG → MoE → EP RS
|
||||
"""
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
AllGather hidden_states and router_logits to form global tensors.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
if enable_sp():
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits)
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||
enable_shared_expert_dp,
|
||||
replace_allreduce)
|
||||
|
||||
def _prepare_with_ep_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def _prepare_with_dp_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
2. Pad local tensors to that size.
|
||||
3. All-gather across DP group to form global input tensor.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
if self.moe_config.dp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
Reduce Scatter hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor with shape [local_num_tokens, hidden_size]
|
||||
"""
|
||||
if enable_sp():
|
||||
return self._finalize_with_ep_group(hidden_states)
|
||||
|
||||
return self._finalize_with_dp_group(hidden_states, reduce_results)
|
||||
|
||||
def _finalize_with_ep_group(self,
|
||||
hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
|
||||
1. Reduce_results is False usually happens when models have shared experts and need to
|
||||
allreduce hidden states after results of shared experts and routed experts are added in FusedMoe.
|
||||
We do reduce scatter for hidden states here, then skip allreudce in FusedMoe and add it to the
|
||||
result of shared experts.
|
||||
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
|
||||
here, then skip allreudce in FusedMoe.
|
||||
"""
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
hidden_states, True)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
|
||||
2. Slice to original local token count.
|
||||
3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
|
||||
|
||||
Returns:
|
||||
Tensor with shape [original_local_num_tokens, hidden_size]
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
if reduce_results and (self.moe_config.tp_size > 1
|
||||
or self.moe_config.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using Naive Multicast (point-to-point broadcast).
|
||||
Used during prefill when decode uses AllGather. Each DP rank broadcasts its slice to all others.
|
||||
Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries.
|
||||
"""
|
||||
|
||||
def _naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
"""
|
||||
Naive multicast implementation:
|
||||
1. Create global buffer sized by total tokens across DP.
|
||||
2. Current rank copies its slice into its designated buffer region.
|
||||
3. Each rank broadcasts its slice to all others via P2P.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Local tensor [local_tokens, hidden_size]
|
||||
cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Global tensor [total_tokens, hidden_size]
|
||||
"""
|
||||
assert len(x.shape) == 2, "Input must be 2D [tokens, features]"
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
# Copy local slice into buffer
|
||||
start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.moe_config.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
|
||||
# Broadcast each slice to all ranks
|
||||
for idx in range(self.moe_config.dp_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
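# Boundary arithmetic used above, as a small standalone sketch: with per-rank token
# counts [3, 5, 2], cu_tokens_across_dp_cpu is [3, 8, 10] and rank 1 owns rows 3..8.
# The _demo_* helper is illustrative only.
def _demo_slice_bounds():
    import torch
    cu_tokens = torch.tensor([3, 8, 10])                # cumulative tokens per DP rank
    dp_rank = 1
    start = 0 if dp_rank == 0 else int(cu_tokens[dp_rank - 1])
    end = int(cu_tokens[dp_rank])
    return start, end                                   # (3, 8)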
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch cumulative token boundaries from forward context.
|
||||
2. Multicast hidden_states and router_logits to form global tensors.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if self.moe_config.dp_size > 1:
|
||||
self.cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_sp(1)
|
||||
hidden_states = self._naive_multicast(hidden_states,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
router_logits = self._naive_multicast(router_logits,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert:
|
||||
- All-reduce across DP
|
||||
- Slice to current rank's token range using cu_tokens_across_dp_cpu
|
||||
2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
|
||||
|
||||
Returns:
|
||||
Tensor with shape [local_num_tokens, hidden_size]
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[
|
||||
self.moe_config.dp_rank - 1]
|
||||
end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
|
||||
hidden_states = get_dp_group().all_reduce(
|
||||
hidden_states) # Sum across DP
|
||||
hidden_states = hidden_states[start:end, :]
|
||||
|
||||
if reduce_results and (self.moe_config.tp_size > 1
|
||||
or self.moe_config.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
273
vllm_npu/ops/moe/moe_comm_method.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_npu.ascend_forward_context import MoECommType
|
||||
from vllm_npu.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_npu.ops.moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_npu.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2,
|
||||
TokenDispatcherWithMoge)
|
||||
|
||||
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
||||
|
||||
|
||||
def get_moe_comm_method(
|
||||
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
|
||||
return _MoECommMethods.get(moe_comm_type, None)
|
||||
|
||||
|
||||
def setup_moe_comm_method(moe_config):
|
||||
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
|
||||
_MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
|
||||
moe_config)
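# Intended lookup pattern for the registry above (a sketch; assumes the platform has
# already built a FusedMoEConfig and called setup once during model initialization):
#
#     setup_moe_comm_method(moe_config)
#     comm = get_moe_comm_method(MoECommType.MC2)
#     hidden_states, router_logits = comm.prepare(hidden_states, router_logits)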
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
|
||||
"""Base class for MoE communication methods."""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.model_type = get_current_vllm_config(
|
||||
).model_config.hf_config.model_type
|
||||
self.moe_config = moe_config
|
||||
self.mc2_mask = None
|
||||
|
||||
self.token_dispatcher = self._get_token_dispatcher()
|
||||
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
|
||||
)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
replace_allreduce)
|
||||
self.mc2_mask = mc2_mask
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
hidden_states = self.fused_moe_prepare_finalize.finalize(
|
||||
hidden_states, reduce_results)
|
||||
return hidden_states
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
global_num_experts: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
results = self.token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
mc2_mask=self.mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8,
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
|
||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
|
||||
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
|
||||
|
||||
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=expert_tokens,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=topk_scales,
|
||||
with_quant=use_int8_w8a8
|
||||
or use_int4_w4a8,
|
||||
fusion=use_int8_w8a8,
|
||||
need_trans=need_trans,
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
|
||||
final_hidden_states = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output)
|
||||
|
||||
if dynamic_eplb:
|
||||
return (final_hidden_states, group_list_type, expert_tokens)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@abstractmethod
|
||||
def _get_token_dispatcher(self):
|
||||
raise NotImplementedError(
|
||||
"_get_token_dispatcher function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
raise NotImplementedError(
|
||||
"_get_fused_moe_prepare_finalize function not implemented.")
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
"""This implementation is the same as NativeAllGatherCommImpl,
|
||||
but uses NPU-specific ops for better performance.
|
||||
|
||||
This implementation should be compatible with all scenarios, and
|
||||
thus it is the default implementation for MoE communication methods.
|
||||
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
|
||||
and `torch_npu.npu_moe_token_unpermute` for post-processing
|
||||
to handle the token-to-expert mapping and communication efficiently.
|
||||
|
||||
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
|
||||
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
|
||||
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
|
||||
for pre-processing and post-processing, respectively.
|
||||
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
|
||||
use `torch_npu.npu_moe_token_unpermute` instead.
|
||||
This is a workaround and should be removed after the issue is fixed.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
if self.model_type == "PanguProMoE":
|
||||
return TokenDispatcherWithMoge(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
else:
|
||||
return TokenDispatcherWithAllGather(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
|
||||
|
||||
class MC2CommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithMC2()
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
|
||||
class AlltoAllCommImpl(MoECommMethod):
|
||||
"""This implementation is for the scenarios listed below:
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_grouped_matmul` is available.
|
||||
|
||||
This implementation uses all-to-all communication to exchange tokens
|
||||
between data parallel ranks before and after the MLP computation. It should
|
||||
have better performance than AllGatherCommImpl when DP size > 1.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithAll2AllV(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
|
||||
|
||||
class NaiveMulticastCommImpl(MoECommMethod):
|
||||
"""This implementation is the same as NativeAllGatherCommImpl,
|
||||
but uses NPU-specific ops for better performance.
|
||||
|
||||
This implementation should be compatible with all scenarios, and
|
||||
thus it is the default implementation for MoE communication methods.
|
||||
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
|
||||
and `torch_npu.npu_moe_token_unpermute` for post-processing
|
||||
to handle the token-to-expert mapping and communication efficiently.
|
||||
|
||||
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
|
||||
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
|
||||
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
|
||||
for pre-processing and post-processing, respectively.
|
||||
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
|
||||
use `torch_npu.npu_moe_token_unpermute` instead.
|
||||
This is a workaround and should be removed after the issue is fixed.
|
||||
"""
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithAllGather(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
|
||||
def _get_fused_moe_prepare_finalize(self):
|
||||
return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
266
vllm_npu/ops/moe/moe_mlp.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch.nn.functional import pad
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_npu.ascend_forward_context import MoECommType
|
||||
from vllm_npu.utils import dispose_tensor, is_310p
|
||||
|
||||
|
||||
def cumsum_group_list(group_list: torch.Tensor,
|
||||
src_list_type: int,
|
||||
dst_list_type: int,
|
||||
active_num: int = 0,
|
||||
expert_num: int = 0) -> torch.Tensor:
|
||||
if src_list_type not in [0, 1, 2]:
|
||||
raise ValueError(
|
||||
f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
|
||||
)
|
||||
|
||||
if src_list_type == dst_list_type:
|
||||
return group_list
|
||||
if src_list_type == 1 and dst_list_type == 0:
|
||||
return group_list.cumsum(dim=0)
|
||||
if src_list_type == 0 and dst_list_type == 1:
|
||||
group_diff = torch.diff(group_list)
|
||||
new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
|
||||
return new_group
|
||||
if src_list_type == 2 and dst_list_type == 0:
|
||||
experts = pad(group_list[:, 0], (1, 0))
|
||||
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
|
||||
cumsum_group_list = torch.full(size=(expert_num, ),
|
||||
fill_value=active_num,
|
||||
dtype=group_list.dtype,
|
||||
device=group_list.device)
|
||||
|
||||
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
|
||||
if end > start:
|
||||
cumsum_group_list[start:end] = tokens[i]
|
||||
|
||||
return cumsum_group_list
|
||||
raise NotImplementedError(
|
||||
f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
|
||||
"This feature is under development.")
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
fusion: bool = False,
|
||||
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
||||
hidden_states)
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
if fusion and not dynamic_eplb:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale)
|
||||
else:
|
||||
if w1_scale.dtype != torch.float32:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=cumsum_group_list(group_list, group_list_type, 1),
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO (w4a8): acquire the output dtype dynamically in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
if fusion and not dynamic_eplb:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
bias=bias1,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale)
|
||||
else:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale.to(w2_scale.dtype)],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
need_trans: bool = True) -> torch.Tensor:
|
||||
|
||||
if need_trans:
|
||||
w1 = w1.transpose(1, 2)
|
||||
w2 = w2.transpose(1, 2)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
return hidden_states
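# Pure-torch reference for the grouped-matmul pattern above with a count-type
# group_list — an illustrative sketch only, with made-up shapes: routed tokens are
# processed in contiguous per-expert chunks.
def _demo_grouped_matmul_reference():
    import torch
    hidden = torch.randn(6, 4)                    # 6 routed tokens, hidden 4
    w = torch.randn(3, 4, 8)                      # 3 local experts, 4 -> 8
    group_list = torch.tensor([2, 3, 1])          # tokens per expert (group_list_type=1)
    outs, start = [], 0
    for e, count in enumerate(group_list.tolist()):
        outs.append(hidden[start:start + count] @ w[e])
        start += count
    return torch.cat(outs, dim=0)                 # (6, 8)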
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
with_quant: bool = False,
|
||||
fusion: bool = False,
|
||||
need_trans: bool = True,
|
||||
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||
if with_quant:
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
fusion=fusion,
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
else:
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales,
|
||||
need_trans=need_trans)
|
||||
725
vllm_npu/ops/moe/token_dispatcher.py
Normal file
@@ -0,0 +1,725 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_npu.distributed.parallel_state import get_mc2_group
|
||||
from vllm_npu.ops.moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_npu.utils import (AscendSocVersion, get_ascend_soc_version,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
"""
|
||||
self.top_k = kwargs.get("top_k", 0)
|
||||
self.num_experts = kwargs.get("num_experts", 0)
|
||||
|
||||
@property
|
||||
def ep_group(self):
|
||||
"""Get expert model parallel group."""
|
||||
return get_ep_group().device_group
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return get_ep_group().rank_in_group
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return get_ep_group().world_size
|
||||
|
||||
@abstractmethod
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False):
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
raise NotImplementedError("Combine function not implemented.")
|
||||
|
||||
|
||||
class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
self.ep_rank_id = get_mc2_group().rank_in_group
|
||||
self.ep_world_size = get_mc2_group().world_size
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = (
|
||||
get_ascend_soc_version() == AscendSocVersion.A3)
|
||||
|
||||
# NOTE: Currently, on A3 we need to pass some extra parameters into dispatch & combine.
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_soc_version() == AscendSocVersion.A3
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.shared_act = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.shared_experts = None
|
||||
self.mc2_mask = None
|
||||
self.with_quant = False
|
||||
self.expand_scales = None
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0,
|
||||
):
|
||||
quant_mode = 2 if self.with_quant else 0
|
||||
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": self.moe_expert_num,
|
||||
"global_bs": 0,
|
||||
"expert_token_nums_type": 0,
|
||||
}
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
}
|
||||
if self.need_extra_args:
|
||||
stage1_kwargs.update({
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
if self.need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales":
|
||||
topk_weights.to(torch.float32),
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False):
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
self.with_quant = with_quant
|
||||
self.expert_map = expert_map
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights
|
||||
self.shared_experts = shared_experts
|
||||
self.mc2_mask = mc2_mask
|
||||
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
|
||||
topk_ids, expert_map,
|
||||
global_redundant_expert_num)
|
||||
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \
|
||||
self.ep_recv_counts, _, self.expand_scales = self.output[0:7]
|
||||
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(shared_gate_up, shared_dequant_scale))
|
||||
self.shared_act, self.swiglu_out_scale = \
|
||||
shared_act_out[0], shared_act_out[1]
|
||||
|
||||
else:
|
||||
if shared_experts is not None:
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
group_list_type = 0
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": expand_x,
|
||||
"group_list": expert_token_nums,
|
||||
"dynamic_scale": dynamic_scale,
|
||||
}
|
||||
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
|
||||
assert self.expert_map is not None
|
||||
assert self.topk_weights is not None
|
||||
assert self.topk_ids is not None
|
||||
assert self.output is not None
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
"expand_x": hidden_states,
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_scales": self.topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": self.moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
if self.with_quant:
|
||||
tp_recv_counts = torch.empty(1,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
else:
|
||||
tp_recv_counts = self.output[5]
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": self.ep_recv_counts,
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
"expand_scales": self.expand_scales,
|
||||
}
|
||||
if self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"assist_info_for_combine":
|
||||
self.assist_info_for_combine,
|
||||
})
|
||||
else:
|
||||
stage3_kwargs.update({
|
||||
"expand_idx": self.assist_info_for_combine,
|
||||
})
|
||||
if self.need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.mc2_mask = None
|
||||
self.expert_map = None
|
||||
self.expand_scales = None
|
||||
|
||||
if self.shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
if self.with_quant:
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
(self.shared_act, self.swiglu_out_scale))
|
||||
else:
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
self.shared_act)
|
||||
self.shared_act = None
|
||||
self.shared_experts = None
|
||||
self.swiglu_out_scale = None
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
class TokenDispatcherWithAllGather(MoETokenDispatcher):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = False
        self.max_num_tokens = kwargs.get("max_num_tokens")
        self.num_experts_local = kwargs.get("num_local_experts", 0)
        self.sorted_weights = None
        self.expanded_row_idx = None
        self.sorted_token_indices = None
        self.original_shape = None
        self.mask = None
        self.expert_map = None
        self.topk_weights = None
        self.topk_ids = None
        self.with_quant = False

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        self.with_quant = with_quant
        self.original_shape = hidden_states.shape

        num_tokens = hidden_states.shape[:-1].numel()
        self.expert_map = expert_map
        self.topk_weights = topk_weights
        self.topk_ids = topk_ids
        self.apply_router_weight_on_input = apply_router_weight_on_input
        if self.apply_router_weight_on_input:
            assert (topk_weights.dim() == 2
                    ), "`topk_weights` should be in shape (num_tokens, topk)"
            _, topk = topk_weights.shape
            assert (
                topk == 1
            ), "Only support topk=1 when `apply_router_weight_on_input` is True"
            hidden_states = hidden_states * \
                topk_weights.to(hidden_states.dtype)
        if expert_map is not None:
            global_num_experts = len(expert_map) + global_redundant_expert_num
            mask = (expert_map[topk_ids] != -1)
            self.topk_weights = topk_weights * mask
            first_expert_idx = get_ep_group(
            ).rank_in_group * self.num_experts_local
            last_expert_idx = first_expert_idx + self.num_experts_local
        else:
            first_expert_idx = 0
            last_expert_idx = self.num_experts_local
            global_num_experts = self.num_experts_local

        sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.top_k,
                expert_num=global_num_experts,
                expert_tokens_num_type=1,
                expert_tokens_num_flag=True,
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=1 if self.with_quant else -1,
            ))
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 1  # `count` mode
        return {
            "group_list_type": group_list_type,
            "hidden_states": sorted_hidden_states,
            "group_list": expert_tokens,
            "dynamic_scale": pertoken_scale if self.with_quant else None,
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        assert self.original_shape is not None
        final_hidden_states = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=hidden_states,
            sorted_indices=torch.abs(self.expanded_row_idx),
            probs=self.topk_weights)
        if len(self.original_shape) == 3:
            final_hidden_states = final_hidden_states.view(self.original_shape)

        # these values are no longer used, so they need to be set to None for memory release.
        self.expert_map = None
        self.topk_weights = None
        self.topk_ids = None
        self.expanded_row_idx = None
        return final_hidden_states

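For orientation, here is a minimal CPU-only sketch of the contract this dispatcher implements: token_dispatch returns the token copies sorted by expert together with a per-expert count list ("group_list" in count mode), the MoE layer then runs one matmul slice per expert, and token_combine unpermutes the expert outputs and applies the routing weights. Everything below (toy_dispatch, toy_combine, the toy weights) is illustrative only and not part of the plugin; the NPU path uses npu_moe_init_routing_v2 and npu_moe_token_unpermute instead of argsort and index_add_.

import torch


def toy_dispatch(hidden_states, topk_ids, num_experts):
    # Sort the token copies by the expert they were routed to.
    flat_ids = topk_ids.reshape(-1)
    order = torch.argsort(flat_ids, stable=True)
    token_idx = order // topk_ids.shape[1]  # source token of each sorted copy
    sorted_hidden = hidden_states[token_idx]
    counts = torch.bincount(flat_ids, minlength=num_experts)  # "group_list" (count mode)
    return sorted_hidden, counts, order


def toy_combine(expert_out, order, topk_weights, num_tokens):
    # Weight each copy by its routing probability and scatter-add back to its token.
    flat_w = topk_weights.reshape(-1)[order].unsqueeze(-1)
    out = torch.zeros(num_tokens, expert_out.shape[-1], dtype=expert_out.dtype)
    out.index_add_(0, order // topk_weights.shape[1], expert_out * flat_w)
    return out


num_tokens, hidden, num_experts, topk = 4, 8, 4, 2
x = torch.randn(num_tokens, hidden)
topk_ids = torch.randint(0, num_experts, (num_tokens, topk))
topk_w = torch.softmax(torch.randn(num_tokens, topk), dim=-1)

sorted_x, counts, order = toy_dispatch(x, topk_ids, num_experts)
w = torch.randn(num_experts, hidden, hidden) * 0.02  # toy expert weights
offsets = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).tolist()
expert_out = torch.cat([
    sorted_x[offsets[e]:offsets[e + 1]] @ w[e] for e in range(num_experts)
])
combined = toy_combine(expert_out, order, topk_w, num_tokens)
print(combined.shape)  # torch.Size([4, 8])
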
# mypy: disable-error-code="override"
class TokenDispatcherWithMoge(MoETokenDispatcher):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = False
        self.local_num_experts = self.num_experts // self.ep_size
        self.local_num_group = self.top_k // self.ep_size
        self.bsz = None

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        self.bsz, _ = hidden_states.shape
        flatten_topk_ids = topk_ids.view(-1)
        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
        sorted_hidden_states = hidden_states.index_select(
            0, self.sorted_topk_ids // self.local_num_group)

        experts_id = torch.arange(0,
                                  self.local_num_experts,
                                  dtype=topk_ids.dtype,
                                  device=topk_ids.device)
        num_tokens_per_expert = (
            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
                torch.float32).sum(0)
        topk_scales = topk_weights.view(-1).index_select(
            0, self.sorted_topk_ids).unsqueeze(-1)
        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
        group_list_type = 0
        return {
            "group_list_type": group_list_type,
            "hidden_states": sorted_hidden_states,
            "group_list": group_list,
            "topk_scales": topk_scales,
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
            torch.int32)
        unsorted_hidden_states = hidden_states.index_select(
            0, unsorted_topk_ids)
        final_hidden_states = unsorted_hidden_states.reshape(
            self.bsz, self.top_k // self.ep_size, -1).sum(1)
        return final_hidden_states

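A tiny CPU example (assumed values, not NPU code) of the index arithmetic used above: the flattened top-k assignments are argsorted by expert id, sorted_topk_ids // local_num_group maps each sorted slot back to its source token row, and argsorting the argsort in token_combine inverts the permutation.

import torch

top_k, ep_size = 2, 1
local_num_group = top_k // ep_size       # expert slots per token handled locally
topk_ids = torch.tensor([[1, 3],         # token 0 -> experts 1 and 3
                         [0, 2]])        # token 1 -> experts 0 and 2
flat = topk_ids.view(-1)                 # [1, 3, 0, 2]; slot i belongs to token i // 2
sorted_slots = torch.argsort(flat.float()).to(torch.int32)
print(sorted_slots)                      # [2, 0, 3, 1] -> experts in order [0, 1, 2, 3]
print(sorted_slots // local_num_group)   # [1, 0, 1, 0] -> token rows to gather
inverse = torch.argsort(sorted_slots.float()).to(torch.int32)
print(inverse)                           # [1, 3, 0, 2]; indexing the permuted rows with
                                         # this restores the original slot order
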
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
    """
    AlltoAll-based token dispatcher that works at the sequence level rather
    than the token level: each device dispatches over the entire sequence,
    with the hidden state partitioned across devices.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.with_quant = False
        self.num_local_experts = kwargs.get("num_local_experts", 0)

        self.hidden_shape = None
        self.topk_weights = None
        self.input_splits = None
        self.output_splits = None
        self.hidden_shape_before_permute = None

        # [tp_ep_size * ep_size, num_local_experts]. Represents the number of
        # tokens sent to each local expert by all ranks.
        self.num_global_tokens_per_local_expert = None

        # cached intermediate tensors.
        self.tokens_per_expert = None
        self.global_input_tokens_local_experts_indices = None

        assert self.num_local_experts > 0, "Expected at least one expert"
        if self.num_local_experts > 1:
            self.expert_ids_per_ep_rank = torch.tensor(
                [i % self.num_local_experts for i in range(self.num_experts)],
                dtype=torch.int32,
                device=torch.npu.current_device(),
            )

        local_expert_indices_offset = (self.ep_rank * self.num_local_experts)

        self.local_expert_indices = [
            local_expert_indices_offset + i
            for i in range(self.num_local_experts)
        ]
        assert (len(self.local_expert_indices) == self.num_local_experts
                ), "Invalid local expert indices"
        for i in range(len(self.local_expert_indices) - 1):
            assert (self.local_expert_indices[i] ==
                    self.local_expert_indices[i + 1] -
                    1), "local_expert_indices must be continuous"

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        self.with_quant = with_quant
        self.hidden_shape = hidden_states.shape
        self.topk_weights = topk_weights
        assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
        assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"

        if log2phy is not None:
            topk_ids = log2phy[topk_ids]

        permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
            hidden_states, topk_ids)
        self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping

        dynamic_scale_after_all2all = None
        if self.with_quant:
            permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
                permutated_local_input_tokens)

            _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
                dynamic_scale,
                self.output_splits,
                self.input_splits,
                self.ep_group,
            )
            permute2_ep_all_to_all_handle.wait()
            dynamic_scale.untyped_storage().resize_(0)

        _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
            permutated_local_input_tokens,
            self.output_splits,
            self.input_splits,
            self.ep_group,
        )
        permute1_ep_all_to_all_handle.wait()
        permutated_local_input_tokens.untyped_storage().resize_(0)

        global_input_tokens, dynamic_scale = self._dispatch_postprocess(
            global_input_tokens, dynamic_scale_after_all2all)
        return {
            "hidden_states": global_input_tokens,
            "group_list": tokens_per_expert,
            "dynamic_scale": dynamic_scale,
            "group_list_type": 1
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

        hidden_states = self._combine_preprocess(hidden_states)

        # Perform expert parallel AlltoAll communication
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        _, permutated_local_input_tokens, handle = async_all_to_all(
            hidden_states, self.input_splits, self.output_splits,
            self.ep_group)
        handle.wait()
        hidden_states.untyped_storage().resize_(0)

        output = self._combine_postprocess(permutated_local_input_tokens)

        # these values are no longer used, so they need to be set to None for memory release.
        self.input_splits = None
        self.output_splits = None
        self.num_global_tokens_per_local_expert = None
        self.topk_weights = None
        self.reversed_local_input_permutation_mapping = None
        self.reversed_global_input_permutation_mapping = None
        self.global_input_tokens_local_experts_indices = None

        return output

    def _dispatch_preprocess(self, hidden_states, topk_ids):
        assert self.hidden_shape is not None
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        tokens_per_expert = self._preprocess(topk_ids)

        self.hidden_shape_before_permute = hidden_states.shape

        permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
            tokens=hidden_states,
            indices=topk_ids,
            num_out_tokens=self.num_out_tokens,
        )
        return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert

    def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
        num_local_tokens_per_expert = torch.histc(topk_ids,
                                                  bins=self.num_experts,
                                                  min=0,
                                                  max=self.num_experts)

        ep_size = self.ep_size

        # Dropless
        self.num_out_tokens = topk_ids.numel()

        # ===================================================
        # Calculate input_splits, output_splits for alltoall-v.
        # ===================================================
        self.input_splits = (num_local_tokens_per_expert.reshape(
            ep_size,
            self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
                                                   non_blocking=True).numpy())
        num_global_tokens_per_expert = gather_from_sequence_parallel_region(
            num_local_tokens_per_expert,
            group=self.ep_group).reshape(ep_size, self.num_experts)
        self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
            0]:self.local_expert_indices[-1] + 1]
        if self.num_global_tokens_per_local_expert is None:
            raise ValueError(
                "num_global_tokens_per_local_expert must be set before sum.")
        self.output_splits = (self.num_global_tokens_per_local_expert.sum(
            axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
        num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
            axis=0)
        # ===================================================
        # num_global_tokens_per_expert: [ep_size, num_experts]
        # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
        # num_tokens_per_local_expert: [num_local_experts]
        # ===================================================

        if self.num_local_experts > 1:
            if self.num_global_tokens_per_local_expert is None:
                raise ValueError(
                    "num_global_tokens_per_local_expert must be set before operations."
                )
            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                self.num_global_tokens_per_local_expert.ravel())
        else:
            # TODO: This full synchronization can be a performance bottleneck.
            # A more granular sync (e.g., blocking D2H copies) should be investigated.
            torch.npu.synchronize()

        return num_tokens_per_local_expert

    def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
        # Early return if no local experts or no tokens
        if self.num_local_experts <= 1:
            return global_input_tokens, None

        # Handle quantized case
        if self.with_quant:
            assert self.global_input_tokens_local_experts_indices is not None, \
                "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
            expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
                -1)
            active_num = self.global_input_tokens_local_experts_indices.numel()

            # Handle case with no active tokens
            if active_num <= 0:
                self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
                return global_input_tokens, dynamic_scale

            # Process with active tokens
            global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
                global_input_tokens,
                expert_idx_2d,
                scale=dynamic_scale,
                active_num=active_num,
                expert_capacity=0,
                expert_num=self.num_local_experts,
                expert_tokens_num_type=1,
                expert_tokens_num_flag=True,
                active_expert_range=[0, self.num_local_experts],
                quant_mode=-1,
                row_idx_type=0)
            return global_input_tokens, expanded_scale

        # Handle non-quantized case
        global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
            global_input_tokens,
            self.global_input_tokens_local_experts_indices)
        return global_input_tokens, None

    def _combine_preprocess(self, hidden_states):
        # Unpermutation 2: expert output to AlltoAll input
        if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
            hidden_states = torch_npu.npu_moe_token_unpermute(
                hidden_states, self.reversed_global_input_permutation_mapping)

        return hidden_states

    def _combine_postprocess(self, permutated_local_input_tokens):
        # Unpermutation 1: AlltoAll output to output
        output = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=permutated_local_input_tokens,
            sorted_indices=self.reversed_local_input_permutation_mapping.to(
                torch.int32),
            probs=self.topk_weights,
            restore_shape=self.hidden_shape_before_permute)

        # Reshape the output tensor
        output = output.view(self.hidden_shape)
        return output
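To make the split bookkeeping in _preprocess concrete, here is a CPU-only sketch with assumed sizes, where a hand-written gathered histogram stands in for gather_from_sequence_parallel_region: input_splits says how many of this rank's token copies go to each EP rank, output_splits how many arrive from each rank for the local experts.

import torch

ep_size, num_local_experts = 2, 2
ep_rank = 0  # pretend we are EP rank 0; experts 0-1 are local, 2-3 live on rank 1

# This rank's histogram of routed token copies over all 4 global experts.
num_local_tokens_per_expert = torch.tensor([3, 1, 2, 4])
input_splits = num_local_tokens_per_expert.reshape(
    ep_size, num_local_experts).sum(dim=1)
print(input_splits.tolist())  # [4, 6]: 4 copies stay here, 6 are sent to rank 1

# Stand-in for the all-gathered histogram: row r holds rank r's per-expert counts.
num_global_tokens_per_expert = torch.tensor([[3, 1, 2, 4],
                                             [2, 2, 1, 0]])
local_cols = slice(ep_rank * num_local_experts, (ep_rank + 1) * num_local_experts)
num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, local_cols]
output_splits = num_global_tokens_per_local_expert.sum(dim=-1)
tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=0)
print(output_splits.tolist())            # [4, 4]: copies received from rank 0 and rank 1
print(tokens_per_local_expert.tolist())  # [5, 3]: group_list for the local grouped matmul
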
315
vllm_npu/ops/register_custom_ops.py
Normal file
@@ -0,0 +1,315 @@
import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_dp_group, get_ep_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op

import vllm_npu.envs as envs_ascend
from vllm_npu.ascend_forward_context import MoECommType
from vllm_npu.ops.weight_prefetch import maybe_npu_prefetch
from vllm_npu.utils import npu_stream_switch, prefetch_stream


def _maybe_all_gather_and_maybe_unpad_impl(
        x: torch.Tensor,
        label: bool,
        is_ep_comm: bool = False) -> torch.Tensor:
    try:
        forward_context = get_forward_context()
    except AssertionError:
        return x

    sp_enabled = forward_context.sp_enabled
    if sp_enabled and label:
        dp_metadata = forward_context.dp_metadata
        if dp_metadata is None or not is_ep_comm:
            x = tensor_model_parallel_all_gather(x, 0)
            pad_size = forward_context.pad_size
            if pad_size > 0:
                x = x[:-pad_size, :]
        else:
            x = get_ep_group().all_gather(x, 0)
            # unpad
            num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
            result = torch.empty(
                (num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
                device=x.device,
                dtype=x.dtype)
            dp_size = get_dp_group().world_size
            x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
            offset = 0
            for idx in range(dp_size):
                num_tokens_dp = num_tokens_across_dp_cpu[idx]
                result[offset:offset +
                       num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
                offset += num_tokens_dp
            x = result

    return x


def _maybe_pad_and_reduce_impl(x: torch.Tensor,
                               is_ep_comm: bool = False) -> torch.Tensor:
    try:
        forward_context = get_forward_context()
    except AssertionError:
        return tensor_model_parallel_all_reduce(x)

    if not forward_context.sp_enabled:
        return tensor_model_parallel_all_reduce(x)

    dp_metadata = forward_context.dp_metadata
    if dp_metadata is None or not is_ep_comm:
        pad_size = forward_context.pad_size
        if pad_size > 0:
            x = F.pad(x, (0, 0, 0, pad_size))
        return tensor_model_parallel_reduce_scatter(x, 0)
    else:
        # padding
        dp_size = get_dp_group().world_size
        num_tokens_across_dp_cpu = \
            get_forward_context().dp_metadata.num_tokens_across_dp_cpu
        padded_x = torch.empty(
            (dp_size, forward_context.padded_length, *x.shape[1:]),
            device=x.device,
            dtype=x.dtype)
        offset = 0
        for idx in range(dp_size):
            num_tokens_dp = num_tokens_across_dp_cpu[idx]
            padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
            offset += num_tokens_dp

        return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
                                             0)

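The two impls above wrap the EP collectives in pad/unpad bookkeeping so that DP ranks with different token counts can share a fixed-size buffer. Below is a CPU-only sketch of that round trip with hypothetical sizes; padded_length is taken here as the per-rank maximum purely for illustration, whereas the plugin reads it from the forward context.

import torch

hidden = 4
num_tokens_across_dp_cpu = torch.tensor([3, 1, 2])      # tokens owned by each DP rank
dp_size = num_tokens_across_dp_cpu.numel()
padded_length = int(num_tokens_across_dp_cpu.max())     # illustrative choice only
x = torch.randn(int(num_tokens_across_dp_cpu.sum()), hidden)  # concatenated real tokens

# Pad: copy each rank's slice into a fixed-size slot (cf. _maybe_pad_and_reduce_impl).
padded = torch.zeros(dp_size, padded_length, hidden)
offset = 0
for idx in range(dp_size):
    n = int(num_tokens_across_dp_cpu[idx])
    padded[idx, :n] = x[offset:offset + n]
    offset += n

# Unpad: keep only the real rows (cf. _maybe_all_gather_and_maybe_unpad_impl).
result = torch.empty(int(num_tokens_across_dp_cpu.sum()), hidden)
offset = 0
for idx in range(dp_size):
    n = int(num_tokens_across_dp_cpu[idx])
    result[offset:offset + n] = padded[idx, :n]
    offset += n

assert torch.equal(result, x)  # the round trip loses nothing
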
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                          prefix: str) -> None:
    try:
        forward_context = get_forward_context()
    except AssertionError:
        return

    if not forward_context.prefetch_mlp_enabled:
        return
    model_instance = forward_context.model_instance
    prefetch_stream = forward_context.prefetch_stream
    layer_idx = int(prefix.split('.')[2])

    # start point of gate_up_proj weight prefetch
    if prefix.split('.')[-2] == "self_attn":
        forward_context.prefetch_mlp_gate_up_proj = True
    if forward_context.prefetch_mlp_gate_up_proj:
        prefetch_stream.wait_stream(torch.npu.current_stream())

        with torch.npu.stream(prefetch_stream):
            mlp_gate_up_prefetch_size = envs_ascend.vllm_npu_MLP_GATE_UP_PREFETCH_SIZE
            torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
                                   x_dependency, mlp_gate_up_prefetch_size)
    return


def _maybe_all_gather_and_maybe_unpad_fake(
        x: torch.Tensor,
        label: bool,
        is_ep_comm: bool = False) -> torch.Tensor:

    if get_forward_context().sp_enabled and label:
        return torch.empty(
            (x.shape[0] * get_tensor_model_parallel_world_size(),
             *x.shape[1:]),
            device=x.device,
            dtype=x.dtype)

    return x


def _maybe_pad_and_reduce_fake(x: torch.Tensor,
                               is_ep_comm: bool = False) -> torch.Tensor:
    if get_forward_context().sp_enabled:
        return torch.empty(
            (x.shape[0] // get_tensor_model_parallel_world_size(),
             *x.shape[1:]),
            device=x.device,
            dtype=x.dtype)

    return x


def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
                                               prefix: str) -> None:
    return


def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
    try:
        forward_context = get_forward_context()
    except AssertionError:
        return

    if not forward_context.prefetch_mlp_enabled:
        return
    forward_context.prefetch_mlp_down_proj = True
    model_instance = forward_context.model_instance
    prefetch_stream = forward_context.prefetch_stream
    layer_idx = forward_context.layer_idx

    # start point of down_proj weight prefetch
    prefetch_stream.wait_stream(torch.npu.current_stream())

    with torch.npu.stream(prefetch_stream):
        mlp_down_prefetch_size = envs_ascend.vllm_npu_MLP_DOWN_PREFETCH_SIZE
        torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
                               x_dependency, mlp_down_prefetch_size)
    forward_context.layer_idx += 1
    return


def _maybe_prefetch_mlp_down_proj_impl_fake(
        x_dependency: torch.Tensor) -> None:
    return


def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
    try:
        forward_context = get_forward_context()
    except AssertionError:
        return

    if not forward_context.prefetch_mlp_enabled:
        return
    if forward_context.prefetch_mlp_gate_up_proj or \
            forward_context.prefetch_mlp_down_proj:
        prefetch_stream = forward_context.prefetch_stream
        # wait until prefetch done
        torch.npu.current_stream().wait_stream(prefetch_stream)
        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
    return


def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
    return


def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
                              max_weight_size: int) -> None:
    calculation_stream = torch_npu.npu.current_stream()
    weight_prefetch_stream = prefetch_stream()
    weight_prefetch_stream.wait_stream(calculation_stream)
    with npu_stream_switch(weight_prefetch_stream):
        maybe_npu_prefetch(inputs=weight,
                           dependency=start_flag,
                           max_size=max_weight_size)


def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
                                   start_flag: torch.Tensor,
                                   max_weight_size: int) -> None:
    return


def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
    calculation_stream = torch_npu.npu.current_stream()
    weight_prefetch_stream = prefetch_stream()
    calculation_stream.wait_stream(weight_prefetch_stream)


def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
    return


def _maybe_all_reduce_tensor_model_parallel_impl(
        final_hidden_states: torch.Tensor) -> torch.Tensor:
    forward_context = get_forward_context()
    moe_comm_type = forward_context.moe_comm_type
    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
                         } or forward_context.sp_enabled:
        return final_hidden_states
    else:
        return tensor_model_parallel_all_reduce(final_hidden_states)


def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
                            layer_name: str) -> torch.Tensor:
    forward_context = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    assert self.custom_op is not None
    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
    output = self.custom_op.matmul_and_reduce(input_parallel, bias_)

    return output


def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
                                 layer_name: str) -> torch.Tensor:
    forward_context = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    num_tokens = input_parallel.size(0)
    if forward_context.sp_enabled:
        num_tokens = num_tokens // self.tp_size
    output = torch.empty(size=(num_tokens, self.output_size_per_partition),
                         device=input_parallel.device,
                         dtype=input_parallel.dtype)

    return output


direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                          op_func=_maybe_all_gather_and_maybe_unpad_impl,
                          fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_pad_and_reduce",
                          op_func=_maybe_pad_and_reduce_impl,
                          fake_impl=_maybe_pad_and_reduce_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
                          op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
                          fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
                          op_func=_maybe_prefetch_mlp_down_proj_impl,
                          fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                          op_func=_maybe_wait_prefetch_done_impl,
                          fake_impl=_maybe_wait_prefetch_done_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="prefetch_preprocess",
                          op_func=_prefetch_preprocess_impl,
                          fake_impl=_prefetch_preprocess_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="prefetch_postprocess",
                          op_func=_prefetch_postprocess_impl,
                          fake_impl=_prefetch_postprocess_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
                          op_func=_maybe_all_reduce_tensor_model_parallel_impl,
                          fake_impl=lambda x: x,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="matmul_and_reduce",
                          op_func=_matmul_and_reduce_impl,
                          fake_impl=_matmul_and_reduce_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

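Each registration above pairs the NPU implementation with a fake implementation, so that graph capture and shape inference can run without executing the real kernel. As a rough, CPU-only analogue of this pattern (using the public torch.library API from PyTorch 2.4+, not vLLM's direct_register_custom_op), one might write:

import torch


@torch.library.custom_op("demo_npu_plugin::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor           # stands in for the real (NPU) computation


@scale_rows.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)  # shape/dtype only, mirroring the *_fake impls above


x = torch.randn(2, 3)
print(torch.ops.demo_npu_plugin.scale_rows(x, 2.0).shape)  # torch.Size([2, 3])
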
Some files were not shown because too many files have changed in this diff.