diff --git a/setup.py b/setup.py index be5731c..7c854cc 100644 --- a/setup.py +++ b/setup.py @@ -20,5 +20,9 @@ setup( "vllm.platform_plugins": [ "npu = vllm_npu:register", ], + "vllm.general_plugins": [ + "npu_enhanced_model = vllm_npu:register_model", + "npu_kv_connector = vllm_npu:register_connector", + ], }, ) diff --git a/vllm_npu/__init__.py b/vllm_npu/__init__.py index b03577b..b6c4c13 100644 --- a/vllm_npu/__init__.py +++ b/vllm_npu/__init__.py @@ -1,38 +1,33 @@ -""" -vllm_npu — Ascend NPU platform plugin for vLLM. - -The ``register()`` function is discovered by vLLM through the -``vllm.platform_plugins`` entry-point and returns the fully-qualified -class name of the platform implementation. -""" +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# def register(): - """Return the fully-qualified name of the NPU platform class.""" - # Apply CUDA→NPU compatibility patches early so that any code - # referencing torch.cuda.Stream / Event / etc. will transparently - # be redirected to the torch.npu equivalents. - from vllm_npu.cuda_compat import _patch_cuda_to_npu - _patch_cuda_to_npu() + """Register the NPU platform.""" return "vllm_npu.platform.NPUPlatform" -def register_npu_ops(): - """Register Ascend NPU op overrides with vLLM's CustomOp system. +def register_model(): - Must be called AFTER the platform is established (e.g., during - worker init or check_and_update_config), NOT during register(). - """ - from vllm.model_executor.custom_op import CustomOp + from .models import register_model + register_model() - from vllm_npu.ops.activation import AscendSiluAndMul - from vllm_npu.ops.layernorm import AscendRMSNorm - from vllm_npu.ops.rotary_embedding import AscendRotaryEmbedding - for name, op_cls in { - "SiluAndMul": AscendSiluAndMul, - "RMSNorm": AscendRMSNorm, - "RotaryEmbedding": AscendRotaryEmbedding, - }.items(): - CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) +def register_connector(): + from vllm_npu.distributed import register_connector + register_connector() diff --git a/vllm_npu/ascend_config.py b/vllm_npu/ascend_config.py new file mode 100644 index 0000000..b0973b1 --- /dev/null +++ b/vllm_npu/ascend_config.py @@ -0,0 +1,310 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
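For reference, vLLM resolves both groups the same way: it loads every entry point in the group and calls it. A minimal sketch of how the new hooks can be sanity-checked once a wheel built from this patch is installed (Python 3.10+, names taken from the setup.py hunk above):

```python
from importlib.metadata import entry_points

# Platform plugin: the callable returns the dotted path of the platform class.
(npu_ep,) = [ep for ep in entry_points(group="vllm.platform_plugins") if ep.name == "npu"]
print(npu_ep.load()())  # -> "vllm_npu.platform.NPUPlatform"

# General plugins: loading and calling them registers the models and the
# KV connector as side effects (register_model / register_connector above).
for ep in entry_points(group="vllm.general_plugins"):
    if ep.name in ("npu_enhanced_model", "npu_kv_connector"):
        ep.load()()
```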
+from typing import Optional + +from vllm.logger import logger + +TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"] + + +def _check_torchair_supported(model_type: str): + for supported_model in TORCHAIR_MODEL_LIST: + if supported_model in model_type.lower(): + return True + return False + + +class AscendConfig: + """ + Configuration Object for additional_config from vllm.configs. + """ + + def __init__(self, vllm_config): + additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} + + torchair_graph_config = additional_config.get("torchair_graph_config", + {}) + self.torchair_graph_config = TorchairGraphConfig( + torchair_graph_config, vllm_config, additional_config) + + ascend_scheduler_config = additional_config.get( + "ascend_scheduler_config", {}) + self.ascend_scheduler_config = AscendSchedulerConfig( + ascend_scheduler_config) + + weight_prefetch_config = additional_config.get( + "weight_prefetch_config", {}) + self.weight_prefetch_config = WeightPrefetchConfig( + weight_prefetch_config) + + # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config + self.expert_map_path = additional_config.get("expert_map_path", None) + self.eplb_policy_type = additional_config.get("eplb_policy_type", 1) + self.expert_map_record_path = additional_config.get( + "expert_map_record_path", + None) # Provide path to export expert map + self.init_redundancy_expert = additional_config.get( + "init_redundancy_expert", 0) + self.dynamic_eplb = additional_config.get("dynamic_eplb", False) + self.num_iterations_eplb_update = additional_config.get( + "num_iterations_eplb_update", 400) + self.gate_eplb = additional_config.get("gate_eplb", False) + self.num_wait_worker_iterations = additional_config.get( + "num_wait_worker_iterations", 30) + self.chunked_prefill_for_mla = additional_config.get( + "chunked_prefill_for_mla", False) + self.enable_shared_expert_dp = additional_config.get( + "enable_shared_expert_dp", False + ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel + self.multistream_overlap_shared_expert = additional_config.get( + "multistream_overlap_shared_expert", False) + self.recompute_scheduler_enable = additional_config.get( + "recompute_scheduler_enable", False) + self.lmhead_tensor_parallel_size = additional_config.get( + "lmhead_tensor_parallel_size", None) + if self.lmhead_tensor_parallel_size is not None: + logger.info( + f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario" + ) + if vllm_config.parallel_config.tensor_parallel_size != 1: + raise AssertionError( + "lmhead_tensor_parallel_size is only supported in the pure DP scenario" + ) + self.oproj_tensor_parallel_size = additional_config.get( + "oproj_tensor_parallel_size", None) + if self.oproj_tensor_parallel_size is not None: + logger.info( + f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario" + ) + if vllm_config.parallel_config.tensor_parallel_size != 1: + raise AssertionError( + "oproj_tensor_parallel_size is only supported in the pure DP scenario" + ) + if not self.torchair_graph_config.enabled: + raise AssertionError( + "oproj_tensor_parallel_size is only supported in graph mode" + ) + if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer: + raise AssertionError( + "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node." 
+ ) + self.enable_cpu_binding = additional_config.get( + "enable_cpu_binding", False) + self.pd_tp_ratio = 1 + self.pd_head_ratio = 1 + self.num_head_replica = 1 + if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla: + prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config( + "prefill", {"tp_size": 1})["tp_size"] + decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config( + "decode", {"tp_size": 1})["tp_size"] + assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size." + self.pd_tp_ratio = prefill_tp_size // decode_tp_size + if self.pd_tp_ratio > 1: + try: + # only support Qwen model now + # TODO: use a more robust method to get kv_head_num + num_kv_head = vllm_config.model_config.hf_config.num_key_value_heads + self.num_head_replica = prefill_tp_size // num_kv_head if prefill_tp_size >= num_kv_head else 1 + prefill_tp_size = min(prefill_tp_size, num_kv_head) + decode_tp_size = min(decode_tp_size, num_kv_head) + self.pd_head_ratio = prefill_tp_size // decode_tp_size + except Exception: + raise AssertionError( + "Can not get num_key_value_heads from model_config") + + if self.pd_tp_ratio == 0: + raise AssertionError( + "Only support P node tp size lagger then D node tp size") + + +class TorchairGraphConfig: + """ + Configuration Object for torchair_graph_config from additional_config + """ + + def __init__(self, torchair_graph_config, vllm_config, additional_config): + self.enabled = torchair_graph_config.get("enabled", False) + self.mode = torchair_graph_config.get("mode", '') + self.use_cached_graph = torchair_graph_config.get( + "use_cached_graph", False) + self.use_cached_kv_cache_bytes = torchair_graph_config.get( + "use_cached_kv_cache_bytes", False) + self.graph_batch_sizes = torchair_graph_config.get( + "graph_batch_sizes", []) + self.graph_batch_sizes_init = torchair_graph_config.get( + "graph_batch_sizes_init", False) + self.enable_multistream_mla = torchair_graph_config.get( + "enable_multistream_mla", False) + self.enable_view_optimize = torchair_graph_config.get( + "enable_view_optimize", True) + self.enable_frozen_parameter = torchair_graph_config.get( + "enable_frozen_parameter", True) + self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) + self.enable_super_kernel = torchair_graph_config.get( + "enable_super_kernel", False) + + if not isinstance(self.graph_batch_sizes, list): + raise TypeError("graph_batch_sizes must be list[int]") + if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0: + raise ValueError( + "graph_batch_sizes_init is only valid when graph_batch_sizes is empty" + ) + if not self.enabled: + if self.mode: + raise RuntimeError( + "mode is valid only when Torchair graph mode is enabled") + if self.use_cached_graph: + raise RuntimeError( + "use_cached_graph is valid only when Torchair graph mode is enabled" + ) + if self.use_cached_kv_cache_bytes: + raise RuntimeError( + "use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled" + ) + if self.graph_batch_sizes: + raise RuntimeError( + "graph_batch_sizes is valid only when Torchair graph mode is enabled" + ) + if self.graph_batch_sizes_init: + raise RuntimeError( + "graph_batch_sizes_init is valid only when Torchair graph mode is enabled" + ) + if self.enable_multistream_mla: + raise RuntimeError( + "enable_multistream_mla is valid only when Torchair graph mode is enabled" + ) + if self.enable_kv_nz: + raise RuntimeError( + "enable_kv_nz is valid only 
when Torchair graph mode is enabled" + ) + if self.enable_super_kernel: + raise RuntimeError( + "enable_super_kernel is valid only when Torchair graph mode is enabled" + ) + if self.enable_super_kernel: + if vllm_config.parallel_config.tensor_parallel_size != 1: + raise RuntimeError( + "enable_super_kernel is valid only when tensor_parallel_size is 1" + ) + if not additional_config.get("multistream_overlap_shared_expert", + False): + raise RuntimeError( + "enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled" + ) + if self.use_cached_kv_cache_bytes and not self.use_cached_graph: + raise RuntimeError( + "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled" + ) + + +class AscendSchedulerConfig: + """ + Configuration Object for ascend_scheduler_config from additional_config + """ + + def __init__(self, ascend_scheduler_config: dict): + self.enabled = ascend_scheduler_config.get("enabled", False) + # Ascend scheduler is based on vllm v0 scheduler, so we should support + # all vllm v0 scheduler configs as well. + for k, v in ascend_scheduler_config.items(): + if not hasattr(self, k): + setattr(self, k, v) + + +class WeightPrefetchConfig: + """ + Configuration Object for weight_prefetch_config from additional_config + """ + + prefetch_ratio: dict = { + "attn": { + "qkv": 1.0, + "o": 1.0, + }, + "moe": { + "gate_up": 0.8 + } + } + + def __init__(self, weight_prefetch_config: dict): + self.enabled = weight_prefetch_config.get("enabled", False) + self.prefetch_ratio = weight_prefetch_config.get( + "prefetch_ratio", self.prefetch_ratio) + + +_ASCEND_CONFIG: Optional[AscendConfig] = None + + +def init_ascend_config(vllm_config): + additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} + refresh = additional_config.get("refresh", + False) if additional_config else False + global _ASCEND_CONFIG + if _ASCEND_CONFIG is not None and not refresh: + return _ASCEND_CONFIG + _ASCEND_CONFIG = AscendConfig(vllm_config) + return _ASCEND_CONFIG + + +def clear_ascend_config(): + global _ASCEND_CONFIG + _ASCEND_CONFIG = None + + +def get_ascend_config(): + global _ASCEND_CONFIG + if _ASCEND_CONFIG is None: + raise RuntimeError( + "Ascend config is not initialized. Please call init_ascend_config first." + ) + return _ASCEND_CONFIG + + +def check_ascend_config(vllm_config, enforce_eager): + ascend_config = get_ascend_config() + + # for eager mode + if enforce_eager: + # torchair_graph cannot be enabled with eager mode. + if ascend_config.torchair_graph_config.enabled: + raise RuntimeError( + "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode." + ) + # for graph mode + else: + # torchair_graph case + if ascend_config.torchair_graph_config.enabled: + # torchair_graph is supported for deepseek/pangu/qwen model only. 
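Taken together, the classes above define the `additional_config` surface parsed by `AscendConfig`. A hypothetical config exercising the main knobs (values are illustrative, not recommendations):

```python
additional_config = {
    "torchair_graph_config": {
        "enabled": True,            # validated against enforce_eager in check_ascend_config()
        "use_cached_graph": True,
        "graph_batch_sizes": [1, 8, 16],
    },
    "ascend_scheduler_config": {
        "enabled": True,            # extra keys are copied onto AscendSchedulerConfig verbatim
    },
    "weight_prefetch_config": {
        "enabled": True,
        "prefetch_ratio": {"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}},
    },
    "refresh": False,               # init_ascend_config() returns the cached object unless True
}
# Lifecycle: init_ascend_config(vllm_config) once during platform setup,
# get_ascend_config() anywhere afterwards, check_ascend_config() to validate
# the combination against enforce_eager and the model type.
```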
+ if vllm_config.model_config: + model_type = vllm_config.model_config.hf_config.model_type + if not _check_torchair_supported(model_type): + raise NotImplementedError( + "Torchair graph mode only works with following model types:" + f"{TORCHAIR_MODEL_LIST}.") + if ascend_config.enable_shared_expert_dp: + logger.warning( + "enable_shared_expert_dp is not supported for torchair graph mode currently, " + "it has been disabled automatically.") + # aclgraph case + else: + if vllm_config.model_config: + model_type = vllm_config.model_config.hf_config.model_type + if "qwen" not in model_type: + logger.warning( + "ACL Graph is currently experimental. Please " + "raise an issue on https://github.com/vllm-project/vllm-ascend/issues" + " if you encourage any Error") diff --git a/vllm_npu/ascend_forward_context.py b/vllm_npu/ascend_forward_context.py new file mode 100644 index 0000000..f268f45 --- /dev/null +++ b/vllm_npu/ascend_forward_context.py @@ -0,0 +1,211 @@ +import math +from contextlib import contextmanager +from enum import Enum +from typing import TYPE_CHECKING, Any, Optional + +import torch +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_world_size) +from vllm.forward_context import (BatchDescriptor, get_forward_context, + set_forward_context) + +import vllm_npu.envs as envs_ascend +from vllm_npu.utils import enable_sp, has_layer_idx, is_moe_model + +if TYPE_CHECKING: + from vllm_npu.ops.weight_prefetch import WeightPrefetchMethod +else: + WeightPrefetchMethod = None + + +class FusedMoEState(Enum): + AllGather = 0 + All2All = 1 + MC2 = 2 + AllGatherEP = 3 + NaiveMulticast = 4 + All2AllSeq = 5 + + +class MoECommType(Enum): + ALLGATHER = 0 + MC2 = 1 + ALLTOALL = 2 + NAIVE_MULTICAST = 3 + + +# TODO(zzzzwwjj): add soc_version to choose branch +def _get_fused_moe_state(ep_size: int, with_prefill: bool, + is_deepseek_v3_r1: bool): + # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep + # only supports deepseek v3/r1 + if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 + and is_deepseek_v3_r1): + return FusedMoEState.AllGatherEP + elif ep_size == 1: + if with_prefill: + return FusedMoEState.NaiveMulticast + else: + return FusedMoEState.AllGather + # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. + elif ep_size < 16 or with_prefill: + return FusedMoEState.All2All + else: + return FusedMoEState.MC2 + + +@contextmanager +def set_ascend_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + with_prefill: bool = True, + in_profile_run: bool = False, + reserved_mc2_mask: Optional[torch.Tensor] = None, + moe_comm_type: Optional[MoECommType] = None, + num_actual_tokens: Optional[int] = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None, + prefetch_stream: torch.npu.Stream = None, + model_instance: torch.nn.Module = None, + weight_prefetch_method: Optional[WeightPrefetchMethod] = None): + """A context manager that stores the current forward context, + can be attention metadata, etc. + We add some additional param into forward_context. 
+ """ + with set_forward_context( + attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + ): + forward_context = get_forward_context() + + from vllm_npu.ops.moe.moe_comm_method import get_moe_comm_method + forward_context.moe_comm_type = moe_comm_type + forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type) + + forward_context.with_prefill = with_prefill + tp_world_size = get_tensor_model_parallel_world_size() + ep_size = (get_ep_group().world_size if + vllm_config.parallel_config.enable_expert_parallel else 1) + + is_deepseek_v3_r1 = hasattr( + vllm_config.model_config.hf_config, 'n_routed_experts' + ) and vllm_config.model_config.hf_config.n_routed_experts == 256 + fused_moe_state = _get_fused_moe_state(ep_size, with_prefill, + is_deepseek_v3_r1) + forward_context.fused_moe_state = fused_moe_state + forward_context.in_profile_run = in_profile_run + + # NOTE: This cannot be set using set_forward_context + # due to multiple warmups before actual capturing + forward_context.capturing = False + + # set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature. + # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, + # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, + # the performance may degrade due to the switching of communication methods. + mmrs_fusion = True + if is_moe_model(vllm_config): + sp_enabled = enable_sp(vllm_config) and \ + tp_world_size > 1 and num_tokens is not None + mmrs_fusion = False + else: + sp_enabled = enable_sp(vllm_config) and \ + tp_world_size > 1 and \ + num_tokens is not None and num_tokens > 1000 + forward_context.mmrs_fusion = mmrs_fusion + + if sp_enabled: + pad_size = (tp_world_size - + (num_tokens % tp_world_size)) % tp_world_size + forward_context.pad_size = pad_size + forward_context.sp_enabled = sp_enabled + forward_context.num_tokens = num_tokens + + # set this for rope forward_oot using + forward_context.is_first_layer = True + + # set layer_idx to enable optimization features that depend on this information. + # This is only applicable to models that contain these necessary attributes. + forward_context.layer_idx = None + if has_layer_idx(model_instance): + forward_context.layer_idx = model_instance.model.start_layer + + # TODO(rjg-lyh): refactor mlp weight prefetch method + # set for mlp weight prefetch + prefetch_mlp_enabled = envs_ascend.vllm_npu_ENABLE_DENSE_OPTIMIZE and \ + envs_ascend.vllm_npu_ENABLE_PREFETCH_MLP and \ + forward_context.layer_idx is not None and \ + num_tokens is not None and num_tokens < 500 + if prefetch_mlp_enabled: + forward_context.prefetch_stream = prefetch_stream + forward_context.model_instance = model_instance + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled + forward_context.model_instance = model_instance + forward_context.weight_prefetch_method = weight_prefetch_method + + # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant. + # It will be improved later by implementing operator fusion through the FX graph. + # + # set for addrmsnorm+quant fusion. 
+ # this optim now just support dense models due to the specific operators used. + # Once the necessary conditions are met, support for MOE models will also be added. + from vllm_npu.quantization.quant_config import AscendQuantConfig + model_type_scope = ["llama", "qwen2", "qwen3", "qwen3_moe"] + addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \ + vllm_config.model_config.hf_config.model_type in model_type_scope and \ + forward_context.layer_idx is not None + if addrmsnorm_quant_fusion_enabled: + forward_context.model_instance = model_instance + forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense" + if vllm_config.model_config.hf_config.model_type == "qwen3_moe": + forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe" + forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled + + if num_tokens is None and attn_metadata is not None: + num_tokens = attn_metadata.num_actual_tokens + + dp_world_size = get_dp_group().world_size + if dp_world_size > 1 and forward_context.dp_metadata is not None: + max_tokens_across_dp = \ + forward_context.dp_metadata.max_tokens_across_dp_cpu.item() + if sp_enabled: + padded_length = (max_tokens_across_dp + tp_world_size - + 1) // tp_world_size * tp_world_size + pad_size = padded_length - num_tokens + forward_context.padded_length = padded_length + forward_context.pad_size = pad_size + else: + max_tokens_across_dp = num_tokens + + forward_context.max_tokens_across_dp = max_tokens_across_dp + + if num_tokens is not None: + if num_actual_tokens is None: + num_actual_tokens = num_tokens + # NOTE: token num which need to pad to when mc2 + forward_context.padded_num_tokens = math.ceil( + max_tokens_across_dp / tp_world_size) * tp_world_size + + if reserved_mc2_mask is not None: + mc2_mask = reserved_mc2_mask[:forward_context. + padded_num_tokens] + mc2_mask[:num_actual_tokens] = True + mc2_mask[num_actual_tokens:] = False + forward_context.mc2_mask = mc2_mask + + try: + yield + finally: + pass diff --git a/vllm_npu/attention/__init__.py b/vllm_npu/attention/__init__.py index d342775..e69de29 100644 --- a/vllm_npu/attention/__init__.py +++ b/vllm_npu/attention/__init__.py @@ -1 +0,0 @@ -"""Ascend NPU attention backends.""" diff --git a/vllm_npu/attention/attention_mask.py b/vllm_npu/attention/attention_mask.py new file mode 100644 index 0000000..2c963b5 --- /dev/null +++ b/vllm_npu/attention/attention_mask.py @@ -0,0 +1,96 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def _generate_attn_mask(max_seq_len, dtype): + # Construct lower triangle matrix. + mask_flag = torch.ones((max_seq_len, max_seq_len), + dtype=torch.bool).tril_() + # Create upper triangle matrix used to mark mask positions. 
+ mask_flag = ~mask_flag + # Currently for fp16 dtype, the mask value should be set to -inf. + # TODO: Eliminate this part in the future. + mask_value = float('-inf') if dtype == torch.float16 else 1 + attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \ + .masked_fill_(mask_flag, mask_value) + return attn_mask + + +class AttentionMaskBuilder: + + def __init__( + self, + max_seq_len: int, + dtype: torch.dtype, + device: torch.device = None, + ): + # NOTE: The device argument specifies the target NPU + # to be used for the newly added FIA operator. + # Only pass this parameter when using the new FIA operator. + + attn_mask = _generate_attn_mask(max_seq_len, dtype) + + self._seq_len_cached = attn_mask.shape[0] + self.attn_mask_cache = attn_mask + self.device = device + self.pooling_mask = None + assigned_mask_dim = 2048 + self.chunked_prefill_attn_mask = torch.triu( + torch.ones(assigned_mask_dim, assigned_mask_dim), + diagonal=1).to(torch.int8).to(device) + + @staticmethod + def get_mask_scale_factor(dtype: torch.dtype = torch.float16): + if dtype == torch.float16: + mask_scale_factor = 1 + elif dtype == torch.bfloat16: + mask_scale_factor = -10000 + else: + raise ValueError( + "The current operation now only supports data types: torch.float16 and " + "torch.bfloat16. Please ensure the input is of one of these types." + ) + return mask_scale_factor + + def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, + device: torch.device): + self._update_attn_cache(max_seq_len, dtype) + return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( + ).to(device, non_blocking=True) + + def get_pooling_mask(self, device): + if self.pooling_mask is None: + # the compressed attention mask for npu_fusion_attention sparse mode 4 + self.pooling_mask = torch.triu(torch.ones( + 2048, 2048), diagonal=1).to(torch.bool).to(device, + non_blocking=True) + return self.pooling_mask + + def get_splitfuse_attn_mask( + self, + seq_lens: torch.Tensor = None, + position: torch.Tensor = None, + dtype: torch.dtype = None, + device: torch.device = None, + ) -> torch.Tensor: + return self.chunked_prefill_attn_mask + + def _update_attn_cache(self, seqlen: int, dtype: torch.dtype): + if seqlen > self._seq_len_cached: + self._seq_len_cached = seqlen + self.attn_mask_cache = _generate_attn_mask(seqlen, dtype) + if self.attn_mask_cache.dtype != dtype: + self.attn_mask_cache = self.attn_mask_cache.to(dtype) diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py index 9b6a3f3..807da61 100644 --- a/vllm_npu/attention/attention_v1.py +++ b/vllm_npu/attention/attention_v1.py @@ -1,76 +1,61 @@ -""" -Ascend NPU attention backend for vLLM v1. - -Implements the ``AttentionBackend``, ``AttentionMetadata``, -``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using -Huawei Ascend NPU FlashAttention operators: - -- ``torch_npu._npu_flash_attention`` — prefill attention (TND layout) -- ``torch_npu._npu_reshape_and_cache`` — KV cache update -- ``torch_npu._npu_paged_attention`` — paged-attention decode -""" +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
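A short usage sketch for `AttentionMaskBuilder` above; the sizes and device are illustrative:

```python
import torch
from vllm_npu.attention.attention_mask import AttentionMaskBuilder

builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.bfloat16)

# Causal mask; the cache is regenerated lazily when a longer sequence arrives.
mask = builder.get_attn_mask(2048, torch.bfloat16, torch.device("cpu"))
assert mask.shape == (2048, 2048)

# Fixed 2048x2048 int8 mask consumed by the chunked-prefill kernels.
compress_mask = builder.get_splitfuse_attn_mask()

# bfloat16 masks store 1s and are scaled by -10000 downstream; fp16 uses -inf directly.
assert AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16) == -10000
```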
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# from dataclasses import dataclass -from enum import IntEnum -from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple +from enum import Enum +from typing import ClassVar, List, Optional, Tuple, Type import torch import torch.nn as nn - -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionImpl, - AttentionType, -) -from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, -) +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionType) +from vllm.config import VllmConfig +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.utils import cdiv, direct_register_custom_op +from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec -if TYPE_CHECKING: - from vllm.config import VllmConfig - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm_npu.attention.utils import (AscendCommonAttentionMetadata, + maybe_save_kv_layer_to_connector, + wait_for_kv_layer_from_connector) +from vllm_npu.compilation.acl_graph import (get_graph_params, + update_graph_params_workspaces) +from vllm_npu.ops.attention import vanilla_chunked_prefill +from vllm_npu.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, + nd_to_nz_2d, nd_to_nz_spec) -logger = init_logger(__name__) - - -# ===================================================================== -# Attention state enum -# ===================================================================== - - -class AscendAttentionState(IntEnum): - """Attention computation state, determines the kernel path.""" - PrefillNoCache = 0 - PrefillCacheHit = 1 - DecodeOnly = 2 - ChunkedPrefill = 3 - - -# ===================================================================== -# Backend class -# ===================================================================== +from ..utils import weak_ref_tensors class AscendAttentionBackend(AttentionBackend): - """Ascend NPU FlashAttention backend.""" - accept_output_buffer: bool = True @staticmethod def get_name() -> str: - return "ASCEND_ATTN" + return "ASCEND" @staticmethod - def get_impl_cls() -> type["AttentionImpl"]: + def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: return AscendAttentionBackendImpl @staticmethod - def get_metadata_cls() -> type["AscendMetadata"]: + def get_metadata_cls() -> Type["AscendMetadata"]: return AscendMetadata @staticmethod @@ -83,15 +68,20 @@ class AscendAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, - **kwargs, ) -> Tuple[int, ...]: - """KV cache shape: (2, num_blocks, block_size, num_kv_heads, head_size). - - The leading ``2`` stores key and value caches in a single tensor. - They are split via ``kv_cache.unbind(0)`` at runtime. 
- """ + if is_310p(): + return (2, num_blocks, num_kv_heads * head_size // 16, block_size, + 16) return (2, num_blocks, block_size, num_kv_heads, head_size) + @staticmethod + def get_bsh_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads * head_size) @staticmethod def swap_blocks( @@ -99,180 +89,213 @@ class AscendAttentionBackend(AttentionBackend): dst_kv_cache: List[torch.Tensor], src_to_dst: torch.Tensor, ) -> None: - """Swap KV cache blocks between src and dst.""" - src_key_cache, src_value_cache = src_kv_cache - dst_key_cache, dst_value_cache = dst_kv_cache + src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] + dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1] + src_indices = src_to_dst[:, 0] + dst_indices = src_to_dst[:, 1] - for src_idx, dst_idx in src_to_dst.tolist(): - dst_key_cache[dst_idx].copy_(src_key_cache[src_idx]) - dst_value_cache[dst_idx].copy_(src_value_cache[src_idx]) + dst_key_cache[dst_indices] = src_key_cache[src_indices].to( + dst_key_cache.device) + dst_value_cache[dst_indices] = src_value_cache[src_indices].to( + dst_key_cache.device) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dsts: torch.Tensor, + src_to_dists: torch.Tensor, ) -> None: - """Copy KV cache blocks in-place.""" - key_caches = [kv[0] for kv in kv_caches] - value_caches = [kv[1] for kv in kv_caches] + src_indices = src_to_dists[:, 0] + dst_indices = src_to_dists[:, 1] - for src_idx, dst_idx in src_to_dsts.tolist(): - for key_cache in key_caches: - key_cache[dst_idx].copy_(key_cache[src_idx]) - for value_cache in value_caches: - value_cache[dst_idx].copy_(value_cache[src_idx]) + for kv_cache in kv_caches: + key_caches = kv_cache[0] + value_caches = kv_cache[1] + key_caches[dst_indices] = key_caches[src_indices] + value_caches[dst_indices] = value_caches[src_indices] + + @staticmethod + def get_supported_block_size() -> list[int]: + return [128] -# ===================================================================== -# Metadata dataclass -# ===================================================================== +class AscendAttentionState(Enum): + PrefillNoCache = 0 + PrefillCacheHit = 1 + DecodeOnly = 2 + ChunkedPrefill = 3 + SpecDecoding = 4 @dataclass class AscendMetadata: - """Per-layer attention metadata for the Ascend backend.""" + # **************************** Basic Properties ************************** # + attn_mask: Optional[torch.Tensor] = None + # Current state of this attention run. attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + # Number of tokens excluding padding. num_actual_tokens: int = 0 - # Sequence lengths and query positions - seq_lens: Optional[torch.Tensor] = None # (batch,) - seq_lens_list: Optional[List[int]] = None - query_start_loc: Optional[torch.Tensor] = None # (batch+1,) - query_lens: Optional[torch.Tensor] = None + # The sequence length per sequence. Sequence length means the computed + # tokens + new tokens (is None if it is a decoding). + # (batch_size,) + # TODO(Angazenn): The following parameters are quite redundant and + # contains similar information (such as seq_lens seq_lens_list). We + # should simplified these parameters once attention schema in vLLM-Ascend + # is unified. 
+ seq_lens: torch.Tensor = None + seq_lens_list: List[int] = None # type: ignore + actual_seq_lengths_q: List[int] = None # type: ignore + + query_start_loc: torch.Tensor = None + query_lens: torch.Tensor = None + # Maximum query length in the batch (None for decoding). max_query_len: Optional[int] = None - actual_seq_lengths_q: Optional[List[int]] = None # cumulative q positions - # KV cache mapping - block_tables: Optional[torch.Tensor] = None # (batch, max_blocks) - slot_mapping: Optional[torch.Tensor] = None # (num_tokens,) + # ********************** KV Cache Related Properties ********************* # + # Block addresses per sequence (Seq id -> list of physical block). + # (batch_size, max_blocks_per_seq) + block_tables: torch.Tensor = None - # Attention mask (for prefill causal masking) - attn_mask: Optional[torch.Tensor] = None + # The indices of the token slots that input tokens will be stored into. + # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the + # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0, + # and 1st slot in block 1, respectively. + # (num_tokens,) + slot_mapping: torch.Tensor = None + + # *************************** Other Properties *************************** # + enable_dbo_across_dp: bool = False -# ===================================================================== -# Metadata builder -# ===================================================================== - - -class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): - """Builds ``AscendMetadata`` from ``CommonAttentionMetadata``.""" - - cudagraph_support: ClassVar[AttentionCGSupport] = ( +class AscendAttentionMetadataBuilder: + # Does this backend/builder support ACL Graphs for attention (default: no). + aclgraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - ) + # Does this backend/builder reorder the batch? + # If not, set this to None. Otherwise set it to the query + # length that will be pulled into the front of the batch. reorder_batch_threshold: ClassVar[int] = 1 def __init__( self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: "VllmConfig", + vllm_config: VllmConfig, device: torch.device, ): - super().__init__(kv_cache_spec, layer_names, vllm_config, device) - self.block_size = kv_cache_spec.block_size - self.num_kv_heads = kv_cache_spec.num_kv_heads - self.head_size = kv_cache_spec.head_size + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + self.max_num_blocks_per_req = cdiv( + self.model_config.max_model_len, + AscendAttentionBackend.get_supported_block_size()[0]) + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" - def reorder_batch( - self, - input_batch: "InputBatch", - scheduler_output: "SchedulerOutput", - ) -> bool: - """ - Reorder so decodes (query_len == 1) come first, prefills after. 
- """ - from vllm.v1.attention.backends.utils import ( - reorder_batch_to_split_decodes_and_prefills, - ) - return reorder_batch_to_split_decodes_and_prefills( - input_batch, scheduler_output, decode_threshold=1 - ) + def reorder_batch(self, input_batch, + scheduler_output: "SchedulerOutput") -> bool: + return False def build( self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> AscendMetadata: - """Build AscendMetadata from the common attention metadata.""" - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - - # Determine attention state + common_attn_metadata: AscendCommonAttentionMetadata, + model: Optional[nn.Module] = None, + ): num_reqs = common_attn_metadata.num_reqs - if max_query_len == 1: - attn_state = AscendAttentionState.DecodeOnly - else: - # Check if this is a pure prefill (no prior cache) or chunked - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] - query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs] - # PrefillNoCache: all requests have query_len == seq_len - if (query_lens_cpu == seq_lens_cpu).all(): - attn_state = AscendAttentionState.PrefillNoCache - else: - attn_state = AscendAttentionState.ChunkedPrefill + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] + block_table = common_attn_metadata.block_table_tensor + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] + attn_mask = common_attn_metadata.attn_mask + attn_state = common_attn_metadata.attn_state + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] - # Build cumulative sequence lengths for query (for prefill) - query_start_loc_cpu_full = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = common_attn_metadata.query_start_loc.to( - dtype=torch.int64 - ) - actual_seq_lengths_q = query_start_loc_cpu_full[1:].tolist() + if attn_state == AscendAttentionState.DecodeOnly and \ + common_attn_metadata.num_input_tokens > num_actual_tokens: + padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens + seq_lens = torch.cat([ + seq_lens, + torch.ones(padded_num_tokens, + dtype=seq_lens.dtype, + device=seq_lens.device) + ]) + block_table_padding = torch.zeros( + (padded_num_tokens, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], dim=0) + query_start_loc_cpu = torch.cat([ + query_start_loc_cpu, + torch.arange(query_start_loc_cpu[-1] + 1, + query_start_loc_cpu[-1] + padded_num_tokens, + dtype=query_start_loc_cpu.dtype, + device=query_start_loc_cpu.device) + ]) - seq_lens = common_attn_metadata.seq_lens - seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() + query_start_loc = query_start_loc_cpu.to(self.device, + non_blocking=True) - # Build attention mask for prefill (causal mask) - attn_mask = None - if attn_state != AscendAttentionState.DecodeOnly: - max_seq = common_attn_metadata.max_seq_len - attn_mask = torch.ones( - max_seq, - max_seq, - dtype=torch.bool, - device=self.device, - ).triu_(diagonal=1) + if is_310p(): + if attn_state == AscendAttentionState.PrefillNoCache: + 
mask_nz = nd_to_nz_2d(attn_mask) + attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), + ACL_FORMAT_FRACTAL_NZ) + elif attn_state == AscendAttentionState.ChunkedPrefill: + mask_nz = nd_to_nz_spec(attn_mask) + attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), + ACL_FORMAT_FRACTAL_NZ) - return AscendMetadata( - attn_state=attn_state, + attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, + block_tables=block_table, query_start_loc=query_start_loc, - max_query_len=max_query_len, - actual_seq_lengths_q=actual_seq_lengths_q, - block_tables=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping, + query_lens=query_lens, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), + max_query_len=common_attn_metadata.max_query_len, + actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(), + slot_mapping=slot_mapping, attn_mask=attn_mask, - ) + attn_state=attn_state, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) + return attn_metadata - def build_for_cudagraph_capture( + def build_for_graph_capture( self, - common_attn_metadata: CommonAttentionMetadata, - ) -> AscendMetadata: - """Build metadata for graph capture (decode-only).""" - return self.build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) + common_attn_metadata: AscendCommonAttentionMetadata, + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + model: Optional[nn.Module] = None, + ): + if attn_state == AscendAttentionState.DecodeOnly: + attn_metadata = self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + raise NotImplementedError( + "Currently we only support building dummy metadata for DecodeOnly state" + ) - -# ===================================================================== -# Attention implementation -# ===================================================================== + attn_metadata.attn_state = attn_state + return attn_metadata class AscendAttentionBackendImpl(AttentionImpl): - """ - Ascend NPU attention kernel implementation. - - Uses ``torch_npu.npu_fusion_attention`` for prefill and - ``torch_npu.npu_incre_flash_attention`` for decode. 
- """ def __init__( self, @@ -283,9 +306,9 @@ class AscendAttentionBackendImpl(AttentionImpl): alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], **kwargs, ) -> None: self.num_heads = num_heads @@ -295,144 +318,17 @@ class AscendAttentionBackendImpl(AttentionImpl): self.hidden_size = self.num_heads * self.head_size self.kv_cache_dtype = kv_cache_dtype self.sliding_window = sliding_window - self.attn_type = attn_type - if alibi_slopes is not None: - alibi_slopes = torch.tensor( - alibi_slopes, dtype=torch.float32, device="npu" - ) + alibi_slopes = torch.tensor(alibi_slopes, + dtype=torch.float32, + device="npu") self.alibi_slopes = alibi_slopes + self.attn_type = attn_type assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - # Cached references to the KV cache tensors - self._key_cache: Optional[torch.Tensor] = None - self._value_cache: Optional[torch.Tensor] = None - - def reshape_and_cache( - self, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: "AscendMetadata", - ): - """Update KV cache with new key/value tensors. - - Uses ``torch_npu._npu_reshape_and_cache`` for efficient in-place - KV cache update, matching vllm-ascend reference. - """ - import torch_npu # noqa: F401 - - if kv_cache.numel() > 0: - if self._key_cache is None: - self._key_cache, self._value_cache = kv_cache[0], kv_cache[1] - - slots = attn_metadata.slot_mapping - num_actual = attn_metadata.num_actual_tokens - torch_npu._npu_reshape_and_cache( - key=key[:num_actual], - value=value[:num_actual], - key_cache=self._key_cache, - value_cache=self._value_cache, - slot_indices=slots, - ) - - return key, value - - def forward( - self, - layer: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with Ascend attention. - - Args: - query: (num_tokens, num_heads * head_size) - key: (num_tokens, num_kv_heads * head_size) - value: (num_tokens, num_kv_heads * head_size) - kv_cache: tensor of shape - (2, num_blocks, block_size, num_kv_heads, head_size) - attn_metadata: AscendMetadata for this forward call. - - Returns: - (num_tokens, num_heads * head_size) - """ - import torch_npu # noqa: F401 - - assert output is not None, "Output tensor must be provided." - num_tokens = query.shape[0] - - if attn_metadata is None: - return output.fill_(0) - - # Reshape Q/K/V to TND (tokens, heads, head_dim) - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - # TODO: Remove this contiguous in the future. 
- value = value.contiguous() - - # Step 1: Update KV cache - if key is not None and value is not None: - key, value = self.reshape_and_cache( - key, value, kv_cache, attn_metadata - ) - - # Step 2: Compute attention - if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - output = self._forward_decode( - query, attn_metadata, output, num_tokens - ) - elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - output = self._forward_prefill_no_cache( - query, key, value, attn_metadata, output, num_tokens - ) - else: - # ChunkedPrefill — use npu_fused_infer_attention_score - output = self._forward_chunked_prefill( - query, attn_metadata, output, num_tokens - ) - - return output - - # ----------------------------------------------------------------- - # Decode path — paged attention via _npu_paged_attention - # ----------------------------------------------------------------- - - def _forward_decode( - self, - query: torch.Tensor, - attn_metadata: AscendMetadata, - output: torch.Tensor, - num_tokens: int, - ) -> torch.Tensor: - """Decode-only via _npu_paged_attention (matches vllm-ascend).""" - import torch_npu # noqa: F401 - - torch_npu._npu_paged_attention( - query=query, - key_cache=self._key_cache, - value_cache=self._value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output, - ) - return output - - # ----------------------------------------------------------------- - # Prefill without KV cache — _npu_flash_attention (TND layout) - # ----------------------------------------------------------------- + self.key_cache = None + self.value_cache = None def _forward_prefill_no_cache( self, @@ -440,61 +336,216 @@ class AscendAttentionBackendImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, - output: torch.Tensor, - num_tokens: int, + output: Optional[torch.Tensor] = None, + num_tokens=0, ) -> torch.Tensor: - """Prefill attention without KV cache via _npu_flash_attention. - - Uses TND layout and a pre-built causal mask from metadata. - This matches vllm-ascend's _forward_prefill_no_cache. 
- """ - import torch_npu # noqa: F401 + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None mask = attn_metadata.attn_mask - torch_npu._npu_flash_attention( - query=query, - key=key, - value=value, - mask=mask, - seq_len=attn_metadata.seq_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - out=output, - ) + if is_310p(): + # align q k v output tensors + query = aligned_16(query) + key = aligned_16(key) + value = aligned_16(value) + output = aligned_16(output) + # do reformat in case of broadcasted tensors + mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1) + mask = torch_npu.npu_format_cast(mask.contiguous(), + ACL_FORMAT_FRACTAL_NZ) + + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) + assert output is not None return output[:num_tokens, :, :] - # ----------------------------------------------------------------- - # Chunked prefill — npu_fused_infer_attention_score (TND layout) - # ----------------------------------------------------------------- - - def _forward_chunked_prefill( + def _forward_prefill_cache_hit( self, query: torch.Tensor, attn_metadata: AscendMetadata, - output: torch.Tensor, - num_tokens: int, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Chunked prefill / mixed prefill+decode via - npu_fused_infer_attention_score, matching vllm-ascend's - _forward_v1_style.""" - import torch_npu # noqa: F401 - - assert self._key_cache is not None + assert attn_metadata is not None assert attn_metadata.attn_mask is not None - num_block, block_size, _, _ = self._key_cache.shape - key = self._key_cache.view(num_block, block_size, -1) - value = self._value_cache.view(num_block, block_size, -1) + compress_mask = attn_metadata.attn_mask + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] - # Trim query to actual tokens (npu_fused_infer_attention_score - # requires query.shape[0] == query_start_loc[-1]) - actual_num_tokens = attn_metadata.query_start_loc[-1] - q = query[:actual_num_tokens] + torch_npu._npu_flash_attention_qlens( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + block_table=block_table, + mask=compress_mask, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) + return output - out, _ = torch_npu.npu_fused_infer_attention_score( - query=q, + def _forward_decode_only( + self, + query: torch.Tensor, + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if is_310p(): + # seq_lens_tensor needs to be transferred to the device for 310P. 
+ attn_metadata.seq_lens = \ + attn_metadata.seq_lens.to(device=query.device) + if self.sliding_window is not None and attn_metadata.seq_lens.shape[ + 0] == query.size(0): + batch_size = attn_metadata.seq_lens.shape[0] + block_size = 128 + query = query.view(batch_size, 1, self.num_heads * self.head_size) + key = self.key_cache + value = self.value_cache + if self.key_cache is not None and self.value_cache is not None: + block_size = self.key_cache.shape[1] + key = self.key_cache.flatten(2, 3).contiguous() + value = self.value_cache.flatten(2, 3).contiguous() + + output, _ = torch_npu.npu_fused_infer_attention_score( + query, + key, + value, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BSH", + block_size=block_size, + pre_tokens=self.sliding_window, + scale=self.scale, + block_table=attn_metadata.block_tables, + actual_seq_lengths=[1] * len(attn_metadata.seq_lens), + actual_seq_lengths_kv=attn_metadata.seq_lens) + + output = output.view(batch_size, self.num_heads, self.head_size) + else: + graph_params = get_graph_params() + forward_context: ForwardContext = get_forward_context() + num_tokens = query.shape[0] + if forward_context.capturing: + # Get workspace from cache or calculate it if not present. + workspace = graph_params.workspaces.get(num_tokens) + if workspace is None: + workspace = torch_npu._npu_paged_attention_get_workspace( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output) + update_graph_params_workspaces(num_tokens, + weak_ref_tensors(workspace)) + + # Handle graph capturing mode + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + graph_params.attn_params[num_tokens].append(( + weak_ref_tensors(query), + weak_ref_tensors(self.key_cache), + weak_ref_tensors(self.value_cache), + self.num_kv_heads, + self.num_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens, + weak_ref_tensors(output), + )) + + torch.npu.graph_task_group_begin(stream) + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output, + workspace=workspace) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + else: + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output) + return output + + def _forward_v1_style( + self, + query: torch.Tensor, + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Use chunked prefill for head size 192 scenario, like deepseek + # paged_attention_splitfuse maybe crash at such scenario. + # TODO: vanilla path will be removed after the kernel support + # head_size 192 scenario. 
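To recap the ACL graph capture path in `_forward_decode_only`, the per-token-count bookkeeping looks like this (my reading of the code; the replay-side consumer presumably lives in `vllm_npu.compilation.acl_graph`, which provides `get_graph_params` and `update_graph_params_workspaces`):

```python
# Bookkeeping populated during capture, keyed by num_tokens:
#   graph_params.workspaces[num_tokens]  -> cached _npu_paged_attention workspace
#   graph_params.events[num_tokens]      -> ExternalEvents to re-synchronize at replay
#   graph_params.attn_params[num_tokens] -> weak refs to (query, key_cache, value_cache,
#                                            ..., block_tables, seq_lens, output)
#   graph_params.handles[num_tokens]     -> handles from torch.npu.graph_task_group_end
```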
+ if self.head_size == 192: + cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() + cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist() + cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device) + cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device) + cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) + cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0) + max_seqlen_q = torch.max(attn_metadata.query_lens) + max_seqlen_k = torch.max(attn_metadata.seq_lens) + vanilla_chunked_prefill(output, query, self.key_cache, + self.value_cache, + attn_metadata.block_tables, cu_seqlen_q, + cu_seqlen_k, max_seqlen_q, max_seqlen_k, + self.scale, None, True) + return output + + # Use paged attention. + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + + if is_310p(): + # Do reformat in case of broadcasted tensors. + attn_metadata.attn_mask = \ + torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), + ACL_FORMAT_FRACTAL_NZ) + attn_metadata.seq_lens = \ + attn_metadata.seq_lens.to(device=query.device) + + # TODO:The npu_fused_infer_attention_score op is planned to + # be utilized in a wider range in upcoming versions. + num_block, block_size, _, _ = self.key_cache.shape # type: ignore + key = self.key_cache.view( # type: ignore + num_block, block_size, -1) + value = self.value_cache.view( # type: ignore + num_block, block_size, -1) + + output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, key=key, value=value, atten_mask=attn_metadata.attn_mask, @@ -509,6 +560,168 @@ class AscendAttentionBackendImpl(AttentionImpl): sparse_mode=3, ) - output[:actual_num_tokens, :, :] = out[:actual_num_tokens, :, :] return output + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor], + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + trace_flag: bool = True, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache: shape = [key_cache, value_cache] + key_cache = [num_blocks, block_size, + num_kv_heads, head_size] + value_cache = [num_blocks, block_size, + num_kv_heads, head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [batch_size * seq_len, num_heads, head_size] + """ + num_tokens = query.shape[0] + use_kv_cache_int8 = len( + kv_cache) > 0 and kv_cache[0].dtype == torch.int8 + if output is None: + output = torch.empty(num_tokens, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + ori_output = output + if trace_flag: + torch.ops.vllm.unified_ascend_attention_with_output( + query=query, + key=key, + value=value, + output=output, + layer_name=layer.layer_name) + + elif hasattr(layer, 'quant_method') and use_kv_cache_int8: + output = layer.quant_method.apply(layer, query, key, value, + kv_cache, attn_metadata, + self.attn_type, self.scale, + output) + + else: + if attn_metadata is None: + return output.view(num_tokens, self.hidden_size).fill_(0) + num_actual_tokens = attn_metadata.num_actual_tokens + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + attn_type = self.attn_type + if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + # View q k v to BSH. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + # TODO: Remove this contiguous in the future. + value = value.contiguous() + + if len(kv_cache) > 1: + if self.key_cache is None: + self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] + slots = attn_metadata.slot_mapping + torch_npu._npu_reshape_and_cache( + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_indices=slots) + if attn_type == AttentionType.ENCODER_ONLY: + cum_seq_len = attn_metadata.query_start_loc[1:].tolist() + attn_out = torch_npu.npu_fusion_attention( + query, + key, + value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=4, + atten_mask=attn_metadata.attn_mask, + pre_tockens=attn_metadata.max_query_len, + next_tockens=attn_metadata.max_query_len, + actual_seq_qlen=cum_seq_len, + actual_seq_kvlen=cum_seq_len, + ) + output = attn_out[0] + # V0-Style scheduler situation. + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + output = self._forward_prefill_no_cache( + query, key, value, attn_metadata, output, num_tokens) + elif attn_metadata.attn_state == \ + AscendAttentionState.PrefillCacheHit: + output = self._forward_prefill_cache_hit( + query, attn_metadata, output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + output = self._forward_decode_only(query, attn_metadata, + output) + # Normal V1 situation. + else: + # npu_fused_infer_attention_score does not support cases + # where query.shape[0] != attn_metadata.query_start_loc[-1]. + # Thus we need unpad it here. 
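+                # Example (illustrative): with query_start_loc = [0, 4, 9]
+                # and a padded query of 16 rows, only the first
+                # query_start_loc[-1] = 9 rows are real tokens, so the query
+                # is truncated below before calling _forward_v1_style.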
+ num_tokens = attn_metadata.query_start_loc[-1] + query = query[:num_tokens] + output = self._forward_v1_style(query, attn_metadata, output) + + # to make in-place change to the output tensor + if hasattr(layer, 'quant_method') and use_kv_cache_int8: + output = output.view(num_tokens, self.num_heads, self.head_size) + ori_output[:num_tokens, :, :] = output[:num_tokens, :, :] + return output.view(num_tokens, self.hidden_size) + + +def unified_ascend_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + kv_cache, + attn_metadata, + output, + trace_flag=False) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_ascend_attention_with_output", + op_func=unified_ascend_attention_with_output, + mutates_args=["output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_npu/attention/mla_v1.py b/vllm_npu/attention/mla_v1.py new file mode 100644 index 0000000..fdeb866 --- /dev/null +++ b/vllm_npu/attention/mla_v1.py @@ -0,0 +1,1326 @@ +from dataclasses import dataclass +from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, + TypeVar) + +import torch +import torch_npu +from torch import nn +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import AttentionCGSupport + +from vllm_npu import envs +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.attention.utils import (AscendCommonAttentionMetadata, + maybe_save_kv_layer_to_connector, + split_decodes_and_prefills, + trans_rope_weight, transdata, + wait_for_kv_layer_from_connector) +from vllm_npu.compilation.acl_graph import (get_graph_params, + update_graph_params_workspaces) +from vllm_npu.multistream.base import MSAttentionMetadataSplitConfig +from vllm_npu.multistream.context import get_multistream_comm_context +from vllm_npu.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_npu.ops.weight_prefetch import maybe_npu_prefetch +from vllm_npu.quantization.w8a8 import AscendW8A8LinearMethod +from vllm_npu.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, + is_enable_nz, weak_ref_tensors) +from vllm_npu.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendMLABackend(AttentionBackend): + + accept_output_buffer: bool = True + + 
@staticmethod + def get_name() -> str: + return "ASCEND_MLA" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendMLAMetadata + + @staticmethod + def get_builder_cls(): + return AscendMLAMetadataBuilder + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["MLAAttentionImpl"]: + return AscendMLAImpl + + +@dataclass +class AscendMLAPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor + + attn_mask: torch.Tensor + query_lens: torch.Tensor + seq_lens: list[int] + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + chunked_context: Optional[ChunkedContextMetadata] = None + sin: torch.Tensor = None + cos: torch.Tensor = None + + +@dataclass +class AscendMLADecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: Optional[list[int]] = None + attn_mask: Optional[torch.Tensor] = None + sin: torch.Tensor = None + cos: torch.Tensor = None + + +@dataclass +class AscendMLAMetadata: + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. 
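+    # Illustrative example for the diagram above: a request with 10
+    # already-computed tokens that is scheduled with 3 new tokens has
+    # context_len = 10, query_len = 3 and seq_len = 13.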
+ + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendMLADecodeMetadata] = None + prefill: Optional[AscendMLAPrefillMetadata] = None + enable_dbo_across_dp: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMLAMetadata"]: + """Split metadata for multi-stream with AscendMLAMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLAMetadata, + ) + + +M = TypeVar("M", bound=AscendMLAMetadata) + + +class AscendMLAMetadataBuilder: + # Does this backend/builder support ACL Graphs for attention (default: no). + aclgraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendMLAMetadata] = None): + self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \ + if metadata_cls is not None else AscendMLAMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" + + self.reorder_batch_threshold = self.decode_threshold + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + self.rope_dim = 
self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= self.decode_threshold: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendMLAMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
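+        # Example (illustrative): with decode_threshold = 1 and scheduled
+        # token counts [1, 1, 5], split_decodes_and_prefills above is expected
+        # to yield num_decodes = 2, num_prefills = 1, num_decode_tokens = 2
+        # and num_prefill_tokens = 5; reorder_batch has already moved the two
+        # decode requests ahead of the prefill request.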
+ device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + model.model.start_layer].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + model.model.start_layer].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[reqs_start:].max().item() + max_seq_lens = seq_lens[reqs_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendMLAPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + prefill_metadata = AscendMLAPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens[reqs_start:].to(torch.int32), + seq_lens=seq_lens, + context_lens=seq_lens[reqs_start:], + input_positions=prefill_input_positions, + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + + decode_metadata = None + if num_decodes > 0: + cos = common_attn_metadata.cos + sin = 
common_attn_metadata.sin + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes] + input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] + seq_lens_list = seq_lens.tolist() + + # TODO: After the fullgraph supports MTP, the if branch needs to deleted + assert self.cos_cache is not None + assert self.sin_cache is not None + if cos is None and sin is None: + cos = self.cos_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + else: + cos[:num_decode_tokens, + ...] = self.cos_cache[input_positions].unsqueeze( + 1).unsqueeze(2) + sin[:num_decode_tokens, + ...] = self.sin_cache[input_positions].unsqueeze( + 1).unsqueeze(2) + + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin[:num_decode_tokens, ...], + cos=cos[:num_decode_tokens, ...]) + + return self.metadata_cls( # type: ignore + num_input_tokens=common_attn_metadata.num_input_tokens, + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + ) + + def build_for_graph_capture( + self, + common_attn_metadata: AscendCommonAttentionMetadata, + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + model: Optional[nn.Module] = None, + ): + if attn_state in { + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + }: + attn_metadata = self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + model=model, + ) + else: + raise NotImplementedError( + "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state" + ) + + attn_metadata.attn_state = attn_state + return attn_metadata + + +class DecodeMLAPreprocessResult(NamedTuple): + ql_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + + +class PrefillMLAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + value: Optional[torch.Tensor] = None + + +class AscendMLAImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: 
float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None) + self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[ + 'q_b_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_prefetch = ascend_config.weight_prefetch_config.enabled + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 + self.prefill_mask = None + + self.speculative_config = vllm_config.speculative_config + self.enable_mlapo = envs.vllm_npu_ENABLE_MLAPO + + def _v_up_proj(self, x): + if x.dtype in [torch.float16, torch.bfloat16] \ + and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"): + x = x.view(-1, self.num_heads, self.kv_lora_rank) + b, _, _ = x.shape + res = torch.empty((b, self.num_heads, self.v_head_dim), + dtype=x.dtype, + device=x.device) + torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) + x = res.reshape(-1, self.num_heads * self.v_head_dim) + else: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return x + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + 
device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + # Weight will be reshaped next. To be on the safe side, the format + # of the weight should be reverted to FRACTAL_AND. + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_ND) + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + + # Function `get_and_maybe_dequant_weights` will cast the weights to + # FRACTAL_AND. So we need to cast to FRACTAL_NZ again. + if is_enable_nz(self.kv_b_proj.weight.data.dtype): + self.kv_b_proj.weight.data = torch_npu.npu_format_cast( + self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ) + + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + + # Currently mlapo only supports W8A8 quantization in MLA scenario + # TODO(whx): modify this limitation when mlapo supports floating point + if self.fused_qkv_a_proj is None or not isinstance( + getattr(self.fused_qkv_a_proj.quant_method, 'quant_method', + None), AscendW8A8LinearMethod): + self.enable_mlapo = False + logger.warning_once( + "Currently mlapo only supports W8A8 quantization in MLA scenario." + "Some layers in your model are not quantized with W8A8," + "thus mlapo is disabled for these layers.") + if self.enable_mlapo: + self._process_weights_for_fused_mlapo(act_dtype) + + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): + kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[ + ..., self.q_lora_rank:].contiguous() + q_a_proj_wt = self.fused_qkv_a_proj.weight.data[ + ..., :self.q_lora_rank].contiguous() + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1) + wd_qkv = wd_qkv.t().contiguous() + wd_qkv = transdata(wd_qkv, + block_size=(16, 32)).unsqueeze(0).contiguous() + self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) + + kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[ + self.q_lora_rank:].contiguous() + q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self. 
+ q_lora_rank].contiguous( + ) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, + self.qk_rope_head_dim) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), + dim=-1).contiguous() + + kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[ + self.q_lora_rank:].contiguous() + q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self. + q_lora_rank].contiguous( + ) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, + self.qk_rope_head_dim) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), + dim=-1).contiguous() + + wu_q = self.q_proj.weight.data + wu_q = wu_q.t().reshape(self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + -1) + wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) + wu_q = wu_q.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), + -1) + wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() + self.wu_q = torch_npu.npu_format_cast(wu_q, 29) + + qb_deq_scl = self.q_proj.deq_scale.data + qb_deq_scl = qb_deq_scl.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) + self.qb_deq_scl = qb_deq_scl.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + qb_qt_bias = self.q_proj.quant_bias.data + qb_qt_bias = qb_qt_bias.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) + self.qb_qt_bias = qb_qt_bias.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + device = self.q_proj.weight.device + self.gamma1 = self.q_a_layernorm.weight.data + self.beta1 = self.q_a_layernorm.bias.data + self.gamma2 = self.kv_a_layernorm.weight.data + self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data + self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data + self.quant_scale1 = self.q_proj.input_scale.data + self.quant_offset1 = self.q_proj.input_offset.data + self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device) + self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device) + + def _compute_prefill_context( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + rope_dim: int, + attn_metadata: AscendMLAMetadata, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + ): + assert len(kv_c_and_k_pe_cache) > 1 + prefill_metadata = attn_metadata.prefill + if prefill_metadata is None or prefill_metadata.chunked_context is None: + return prefix_output, prefix_lse + + iters = len(prefill_metadata.chunked_context.seq_tot) + + current_seq_len = torch.tensor(prefill_metadata.query_lens, + dtype=torch.int32) + cache_kv_c = kv_c_and_k_pe_cache[0] + cache_k_pe = kv_c_and_k_pe_cache[1] + num_heads = cache_k_pe.size(2) + latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = 
prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) + kv_c_normed = torch.empty(toks, + num_heads, + latent_kv_dim, + dtype=q_nope.dtype, + device=q_nope.device) + k_pe = torch.empty(toks, + num_heads, + rope_dim, + dtype=q_nope.dtype, + device=q_nope.device) + + torch_npu.atb.npu_paged_cache_load( + cache_kv_c, + cache_k_pe, + prefill_metadata.block_table, + context_seq_len_npu, + seq_starts=prefill_metadata.chunked_context.starts[i], + key=kv_c_normed, + value=k_pe, + ) + + kv_c_normed = kv_c_normed.squeeze() + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=self.prefill_mask, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse) + return prefix_output, prefix_lse + + def _forward_prefill( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + value: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + attn_metadata: AscendMLAMetadata, + ) -> torch.Tensor: + assert attn_metadata.prefill is not None + assert len(kv_c_and_k_pe_cache) > 1 + num_tokens = q_nope.size(0) + attn_output = torch.empty(num_tokens, + self.num_heads, + self.v_head_dim, + dtype=q_nope.dtype, + device=q_nope.device) + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=q_nope.device) + if self.prefill_mask is None: + if q_nope.dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + prefill_mask = torch.triu( + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, + device=q_nope.device, + dtype=q_nope.dtype), 1) + self.prefill_mask = torch.where(prefill_mask == 1, mask_value, + 0).to(q_nope.dtype) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=self.prefill_mask, + seqlen=attn_metadata.prefill.query_lens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) + + attn_output = attn_output.reshape( + [num_tokens, self.num_heads * self.v_head_dim]) + return attn_output + + def exec_kv_decode( + self, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + B = kv_no_split.shape[0] + N = self.num_kv_heads + S = 1 + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv_no_split = kv_no_split.view( + B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), 
+ kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + ) + return k_pe, k_nope + + def exec_kv_prefill( + self, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + B = kv_no_split.shape[0] + N = self.num_kv_heads + S = 1 + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv_no_split = kv_no_split.view( + B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + is_output_kv=True, + ) + return k_pe, k_nope + + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch_npu.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + block_size: int, + attn_metadata: AscendMLAMetadata, + ) -> torch.Tensor: + decode_meta = attn_metadata.decode + assert decode_meta is not None + num_tokens = q_nope.size(0) + # shape of knope/k_pe for npu graph mode should be: + # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] + actual_seq_lengths = None + if self.enable_kv_nz: + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // 16, block_size, 16) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // 16, block_size, 16) + input_layout = "BSND" + else: + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + input_layout = "BNSD" + + if attn_metadata.attn_state in [ + AscendAttentionState.SpecDecoding, + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.DecodeOnly, + ] and self.speculative_config is not None: + # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill + input_layout = "TND" + # [bs * q_seq_len, num_heads_per_rank, dim] + # TODO: If the driver is upgraded later, the contiguous function can be deleted. 
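+            # Note: in the TND layout each request can contribute several
+            # query tokens (speculative decoding), so actual_seq_lengths_q
+            # carries cumulative per-request query end offsets while
+            # actual_seq_lengths_kv carries per-request KV lengths.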
+ q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous() + q_pe = q_pe.view(num_tokens, self.num_heads, -1) + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + actual_seq_lengths = decode_meta.actual_seq_lengths_q + else: + if self.enable_kv_nz: + q_nope = q_nope.view(num_tokens, 1, self.num_heads, + -1).contiguous() + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + q_nope = q_nope.view(num_tokens, self.num_heads, 1, + -1).contiguous() + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None + + common_kwargs = { + 'query_rope': q_pe, + 'key_rope': k_pe, + 'num_heads': self.num_heads, + 'num_key_value_heads': self.num_kv_heads, + 'input_layout': input_layout, + 'atten_mask': spec_attn_mask, + 'sparse_mode': sparse_mode, + 'scale': self.scale, + 'antiquant_mode': 0, + 'antiquant_scale': None, + 'block_table': decode_meta.block_table, + 'block_size': block_size, + "actual_seq_lengths": actual_seq_lengths, + "actual_seq_lengths_kv": decode_meta.seq_lens_list, + } + graph_params = get_graph_params() + forward_context: ForwardContext = get_forward_context() + if forward_context.capturing: + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + + workspace = graph_params.workspaces.get(num_tokens) + if workspace is None: + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + q_nope, k_nope, k_nope, **common_kwargs) + update_graph_params_workspaces(num_tokens, + weak_ref_tensors(workspace)) + + attn_output = torch.empty_like(q_nope) + softmax_lse = torch.empty(num_tokens, + dtype=q_nope.dtype, + device=q_nope.device) + + graph_params.attn_params[num_tokens].append( + (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), + weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), + self.num_heads, self.num_kv_heads, input_layout, + weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None + else None, sparse_mode, self.scale, decode_meta.block_table, + block_size, decode_meta.seq_lens_list, actual_seq_lengths, + weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse))) + + torch.npu.graph_task_group_begin(stream) + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + **common_kwargs, + workspace=workspace, + out=[attn_output, softmax_lse]) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + else: + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + q_nope, k_nope, k_nope, **common_kwargs) + + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self._v_up_proj(attn_output) + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self._v_up_proj(attn_output) + + def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata): + bsz = attn_metadata.num_decode_tokens + hidden_states = hidden_states[:bsz] + + cos_shape = attn_metadata.decode.cos.shape + cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1]) + sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) + + decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] + decode_q_nope = torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], + decode_k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + decode_q_pe 
= torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], + decode_k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.W_UK_T, + decode_k_nope, + decode_k_pe, + attn_metadata.slot_mapping[:bsz].flatten(), + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + q_nope_scale=self.q_nope_scale, + cache_mode="krope_ctkv", + quant_mode="per_tensor_quant_asymm", + q_out0=decode_q_nope, + kv_cache_out0=decode_k_nope, + q_out1=decode_q_pe, + kv_cache_out1=decode_k_pe, + ) + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + decode_preprocess_res = DecodeMLAPreprocessResult( + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) + return decode_preprocess_res, None + + def _mla_preprocess(self, layer_name, hidden_states, kv_cache, + attn_metadata, need_gather_q_kv): + # MLA Preprocess: + # 1. Perform q_a_proj and q_a_layernorm to obtain q_c + # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split + # 3. If need_gather_q_kv, perform all_gather. + # 4. Preprocess decode tokens, write kv cache and get: + # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope + # 5. Preprocess prefill tokens, write kv cache and get: + # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + num_actual_tokens = attn_metadata.num_actual_tokens + if self.fused_qkv_a_proj is not None: + maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, + dependency=hidden_states, + enabled=self.enable_prefetch) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_no_split = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + # allgather need contiguous data + kv_no_split = kv_no_split.contiguous() + else: + q_c = hidden_states + kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] + + # Process for Flash Comm V1 + q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + q_c, need_gather_q_kv) + kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + kv_no_split, need_gather_q_kv) + + decode_preprocess_res = None + prefill_preprocess_res = None + if has_prefill: + wait_for_kv_layer_from_connector(layer_name) + # Preprocess for decode tokens + if has_decode: + decode_q_c = q_c[:num_decode_tokens] + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_q_c) + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + decode_slots = attn_metadata.slot_mapping[:num_decode_tokens] + decode_kv_no_split = kv_no_split[:num_decode_tokens] + decode_k_pe, decode_k_nope = self.exec_kv_decode( + decode_kv_no_split, cos, sin, kv_cache, decode_slots) + decode_preprocess_res = DecodeMLAPreprocessResult( + decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe) + # Preprocess for prefill tokens + if has_prefill: + prefill_kv_no_split = kv_no_split[ + num_decode_tokens:num_actual_tokens] + prefill_q_c = 
q_c[num_decode_tokens:num_actual_tokens] + prefill_q = self.q_proj(prefill_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + prefill_slots = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( + prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) + prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], + self.num_kv_heads, -1) + prefill_k_nope, prefill_value = self.kv_b_proj( + prefill_k_c_normed)[0].view( + -1, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + prefill_k_pe = prefill_k_pe.expand( + (*prefill_k_nope.shape[:-1], -1)) + prefill_preprocess_res = PrefillMLAPreprocessResult( + prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, + prefill_value) + return decode_preprocess_res, prefill_preprocess_res + + def forward( + self, + layer_name, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + num_actual_tokens = attn_metadata.num_actual_tokens + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + num_decode_tokens = attn_metadata.num_decode_tokens + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + o_proj_input_shape = (get_forward_context().num_tokens, + self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + + # MLA Preprocess + forward_context = get_forward_context() + if (self.enable_mlapo and + (attn_metadata is None or not forward_context.with_prefill)): + decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess( + hidden_states, kv_cache, attn_metadata) + else: + decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess( + layer_name, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv) + + if decode_preprocess_res is not None: + # MLA Preprocess for decoding + output_decode = self._forward_decode(decode_preprocess_res.ql_nope, + decode_preprocess_res.q_pe, + decode_preprocess_res.k_nope, + decode_preprocess_res.k_pe, + kv_cache[0].shape[1], + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + o_proj_input[:num_decode_tokens] = output_decode + current_ms_metadata.after_comm_event.record() + else: + o_proj_input[:num_decode_tokens] = output_decode + + if prefill_preprocess_res is not None: + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + output_prefill = self._forward_prefill( + prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, + prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, + prefill_preprocess_res.value, kv_cache, attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not 
None: + with torch.npu.stream(current_ms_metadata.comm_stream): + o_proj_input[num_decode_tokens:] = output_prefill + current_ms_metadata.after_comm_event.record() + else: + o_proj_input[ + num_decode_tokens:num_actual_tokens] = output_prefill + # O proj + current_ms_metadata = get_multistream_comm_context() + MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 + if current_ms_metadata is None: + maybe_npu_prefetch(inputs=self.o_proj.weight, + dependency=o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch) + + output[...] = self.o_proj(o_proj_input, + is_prefill=prefill_preprocess_res + is not None)[0] + else: + with torch.npu.stream(current_ms_metadata.comm_stream): + maybe_npu_prefetch(inputs=self.o_proj.weight, + dependency=o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch) + output[...] = self.o_proj(o_proj_input, + is_prefill=prefill_preprocess_res + is not None)[0] + current_ms_metadata.after_comm_event.record() + del o_proj_input + + has_prefill = attn_metadata.num_prefills > 0 + if has_prefill: + maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) + return output_padded diff --git a/vllm_npu/attention/sfa_v1.py b/vllm_npu/attention/sfa_v1.py new file mode 100644 index 0000000..a665c3c --- /dev/null +++ b/vllm_npu/attention/sfa_v1.py @@ -0,0 +1,988 @@ +from dataclasses import dataclass +from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, + TypeVar) + +import torch +import torch_npu +from torch import nn +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import AttentionCGSupport + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.attention.mla_v1 import AscendMLAMetadata +from vllm_npu.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_npu.multistream.base import MSAttentionMetadataSplitConfig +from vllm_npu.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_npu.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendSFABackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_SFA" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendSFAMetadata + + @staticmethod + def get_builder_cls(): + return AscendSFAMetadataBuilder + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["AscendSFAImpl"]: + return AscendSFAImpl + + +@dataclass +class AscendSFAPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + + attn_mask: torch.Tensor + query_lens: 
list[int] + seq_lens: list[int] + + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + sin: torch.Tensor + cos: torch.Tensor + chunked_context: Optional[ChunkedContextMetadata] = None + + +@dataclass +class AscendSFADecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: torch.Tensor + sin: torch.Tensor + cos: torch.Tensor + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendSFAMetadata: + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendSFADecodeMetadata] = None + prefill: Optional[AscendSFAPrefillMetadata] = None + enable_dbo_across_dp: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendSFAMetadata"]: + """Split metadata for multi-stream with AscendSFAMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLAMetadata, + ) + + +M = TypeVar("M", bound=AscendSFAMetadata) + + +class AscendSFAMetadataBuilder: + # Does this backend/builder support ACL Graphs for attention (default: no). 
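+    # Unlike the MLA builder above (UNIFORM_BATCH), the SFA builder opts out
+    # of ACL graph capture for attention here.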
+ aclgraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.NEVER + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendSFAMetadata] = None): + self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \ + if metadata_cls is not None else AscendSFAMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" + + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= self.decode_threshold: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. 
past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendSFAMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. + device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[: + num_actual_tokens].to( + device, + non_blocking=True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + model.model.start_layer].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + model.model.start_layer].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[reqs_start:].max().item() + max_seq_lens = seq_lens[reqs_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + 
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendSFAPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + actual_query_lens = torch.tensor(query_lens[reqs_start:], + dtype=torch.int32).npu() + query_lens_prefill_sfa = torch.cumsum(actual_query_lens, + dim=0).to(torch.int32) + seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu() + prefill_metadata = AscendSFAPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens_prefill_sfa, + seq_lens=seq_lens_prefill_sfa, + context_lens=seq_lens[reqs_start:], + input_positions=prefill_input_positions, + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + + decode_metadata = None + if num_decodes > 0: + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( + torch.int32).npu() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() + input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] 
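+            # At this point input_positions, block_table and seq_lens are all
+            # narrowed to the decode-only prefix of the reordered batch.
+            # seq_lens_list keeps a host-side copy of the KV lengths for
+            # consumers that need plain Python ints (e.g. the ACL-graph
+            # parameter update path).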
+ seq_lens_list = seq_lens.tolist() + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + decode_metadata = AscendSFADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + + return self.metadata_cls( # type: ignore + num_input_tokens=common_attn_metadata.num_input_tokens, + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + ) + + +class PrefillSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + + +class DecodeSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + # nope_cache: Optional[torch.Tensor] = None + # rope_cache: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + bsz: Optional[int] = None + + +class AscendSFAImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.indexer = kwargs['indexer'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.q_a_proj = kwargs.get('q_a_proj', None) + self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + if self.q_a_proj is not None: + self.q_b_proj = self.q_proj + else: + self.q_b_proj = None + + ascend_config = get_ascend_config() + 
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 + self.prefill_mask = None + + # indexer param + self.dim = self.indexer.dim + self.n_heads: int = self.indexer.n_heads # 64 + self.head_dim: int = self.indexer.head_dim # 128 + self.index_topk: int = self.indexer.index_topk # 2048 + self.wq_b = self.indexer.wq_b + self.wk = self.indexer.wk + self.weights_proj = self.indexer.weights_proj + self.k_norm = self.indexer.k_norm + self.softmax_scale = self.indexer.softmax_scale + + # Adapt torch air graph mode with spec decoding. + speculative_config = vllm_config.speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 + + self.cp_size = 1 + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() + + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + + def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv): + # SFA Preprocess: + # 1. Perform q_a_proj and q_a_layernorm to obtain q_c + # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split + # 3. If need_gather_q_kv, perform all_gather. + # 4. Preprocess decode tokens, write kv cache and get: + # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope + # 5. 
Preprocess prefill tokens, write kv cache and get: + # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + + num_decode_tokens = attn_metadata.num_decode_tokens + num_actual_tokens = attn_metadata.num_actual_tokens + if need_gather_q_kv: + # q_c = get_tp_group().all_gather(q_c, 0) + # kv_no_split = get_tp_group().all_gather(kv_no_split, 0) + hidden_states = get_tp_group().all_gather(hidden_states, 0) + # hidden_states_decode = hidden_states[:num_decode_tokens] + # if self.q_a_proj is not None: + # npu_prefetch(self.q_a_proj.weight, + # hidden_states, + # enabled=self.enable_prefetch) + # ckq = self.q_a_proj(hidden_states) # q down + # q_c = self.q_a_layernorm(ckq) # q down layernorm + # else: + # q_c = hidden_states + + # kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv + # Process for shared_expert_dp + + decode_preprocess_res = None + prefill_preprocess_res = None + # Preprocess for decode tokens + if has_decode: + q_len = 1 + hidden_states_decode = hidden_states[:num_decode_tokens] + decode_kq = self.q_a_proj(hidden_states_decode) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + decode_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_decode) # c_kv + + # decode_q_c = q_c[:num_decode_tokens] + decode_slot_mapping = attn_metadata.slot_mapping[: + num_decode_tokens] + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + decode_q = self.q_b_proj(decode_q_c) + bsz, _ = decode_q.shape + decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim) + decode_q_nope, decode_q_pe = torch.split( + decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + decode_q_nope = decode_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + decode_q_nope = (torch.matmul(decode_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(bsz, q_len, + self.num_heads, + self.kv_lora_rank)) + + # stream2 kv + key_cache = kv_cache[0] + value_cache = kv_cache[1] + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1) + decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + decode_kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + decode_slot_mapping.to(torch.int64), + value_cache, + key_cache, + c_kv_scale=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode='PA') # adapter NZ + # nz_block_size = 16 + # KVCACHE_NZ_DIM = 16 + # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) + # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) + + decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos, + sin) # BNSD + + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + topk_indices = self.indexer_select(hidden_states_decode, + decode_q_c, + attn_metadata=attn_metadata, + cos=cos, + sin=sin, + kv_cache=kv_cache) + + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_rope) + decode_preprocess_res = DecodeSFAPreprocessResult( + q_nope=decode_q_nope, + q_pe=decode_q_pe, + # nope_cache = nope_cache, + # rope_cache = rope_cache, + 
topk_indices=topk_indices, + query_states=query_states, + key_states=key_states, + bsz=bsz, + ) + + # Preprocess for prefill tokens + if has_prefill: + bsz = 1 + + hidden_states_prefill = hidden_states[ + num_decode_tokens:num_actual_tokens] + prefill_kq = self.q_a_proj(hidden_states_prefill) # q down + prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm + prefill_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_prefill) # c_kv + + # prefill_q_c = q_c[ + # num_decode_tokens:num_actual_tokens] + prefill_slot_mapping = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + prefill_slot_mapping = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + # prefill_kv_no_split = kv_no_split[ + # num_decode_tokens:num_actual_tokens] + # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] + prefill_qr = prefill_q_c + prefill_q = self.q_b_proj(prefill_qr) + prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) + prefill_q_nope, prefill_q_pe = torch.split( + prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + prefill_q_nope = prefill_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + prefill_q_nope = (torch.matmul(prefill_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(-1, self.num_heads, + self.kv_lora_rank)) + prefill_q_pe = prefill_q_pe.unsqueeze(2) + + # stream2 kv + + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + cos_q, sin_q = cos, sin + + # cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + # sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + prefill_q_pe = torch_npu.npu_interleave_rope( + prefill_q_pe, cos_q, sin_q) # BNSD + prefill_q_pe = prefill_q_pe.squeeze(2) #BSH + # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? + + prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) + prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + prefill_latent_cache.view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), + self.kv_a_layernorm.weight, + cos.view(-1, 1, 1, self.qk_rope_head_dim), + sin.view(-1, 1, 1, self.qk_rope_head_dim), + prefill_slot_mapping.to(torch.int64), + rope_cache, + nope_cache, + k_rope_scale=None, + c_kv_scale=None, + k_rope_offset=None, + c_kv_offset=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode="PA") + + topk_indices = self.indexer_select(x=hidden_states_prefill, + qr=prefill_qr, + kv_cache=kv_cache, + cos=cos, + sin=sin, + attn_metadata=attn_metadata) + query_states = (prefill_q_nope, prefill_q_pe) + key_states = (prefill_k_nope, prefill_k_pe) + prefill_preprocess_res = PrefillSFAPreprocessResult( + q_nope=prefill_q_nope, + q_pe=prefill_q_pe, + topk_indices=topk_indices, + k_nope=prefill_k_nope, + k_pe=prefill_k_pe, + query_states=query_states, + key_states=key_states, + ) + + return decode_preprocess_res, prefill_preprocess_res + + def forward( + self, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. 
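+            # No attention metadata is available during the profiling run, so
+            # just return a zero-filled output of the expected shape.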
+ return output.fill_(0) + num_actual_tokens = attn_metadata.num_actual_tokens + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + num_decode_tokens = attn_metadata.num_decode_tokens + # Inputs and outputs may be padded for CUDA graphs + output = output[:num_actual_tokens, ...] + o_proj_input_shape = (num_actual_tokens, + self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + + # SFA Preprocess + decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess( + hidden_states, kv_cache, attn_metadata, need_gather_q_kv) + + if decode_preprocess_res is not None: + # bsz, q_len, _, _ = query_states[0].shape + decode_attn_output = self.apply_attention_fusion( + query_states=decode_preprocess_res.query_states, + key_states=decode_preprocess_res.key_states, + attn_metadata=attn_metadata, + topk_indices=decode_preprocess_res.topk_indices) + o_proj_input[:num_decode_tokens] = decode_attn_output + + if prefill_preprocess_res is not None: + prefill_attn_output = self.apply_attention_fusion( + query_states=prefill_preprocess_res.query_states, + key_states=prefill_preprocess_res.key_states, + attn_metadata=attn_metadata, + topk_indices=prefill_preprocess_res.topk_indices) + o_proj_input[num_decode_tokens:] = prefill_attn_output + + output[...] = self.mla_epilog(o_proj_input, absorb=True) + return output + + def apply_attention_fusion(self, query_states, key_states, topk_indices, + attn_metadata: M): + # repeat k/v heads if n_kv_heads < n_heads + q_nope, q_pe = query_states + k_nope, k_rope = key_states + + if attn_metadata.prefill is not None: + + prefill_metadata = attn_metadata.prefill + + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=prefill_metadata.block_table, + actual_seq_lengths_query=prefill_metadata.query_lens, + actual_seq_lengths_kv=prefill_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + + elif attn_metadata.decode is not None: + decode_metadata = attn_metadata.decode + + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=attn_metadata.decode.block_table, + actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=decode_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.squeeze(1) + + slc_fa_fusion = slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul(slc_fa_fusion, + self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + # Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and + # with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1. 
+ # after reshape: [T(bs*q_len), 1, N//attn_tp_size*D] + # attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim) + + return attn_output + + def mla_epilog(self, + attn_output: torch.Tensor = None, + absorb: bool = False): + # TODO: need to check + attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0], + -1), + is_prefill=True, + is_force_scatter=False) + + return attn_output + + def indexer_select( + self, + x: torch.Tensor, + qr: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + cos, + sin, + attn_metadata: M, + ): + if attn_metadata.prefill is not None: + actual_seq_lengths_query = attn_metadata.prefill.query_lens + actual_seq_lengths_key = attn_metadata.prefill.seq_lens + block_table = attn_metadata.prefill.block_table + elif attn_metadata.decode is not None: + actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q + actual_seq_lengths_key = attn_metadata.decode.seq_lens + block_table = attn_metadata.decode.block_table + + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + # q process in new stream + q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + k = self.k_norm(k_proj).unsqueeze(1) + k_pe, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + + weights = self.weights_proj(x) + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3) + return topk_indices diff --git a/vllm_npu/attention/utils.py b/vllm_npu/attention/utils.py new file mode 100644 index 0000000..1ad81c0 --- /dev/null +++ b/vllm_npu/attention/utils.py @@ -0,0 +1,180 @@ +from dataclasses import dataclass +from typing import Any, List + +import torch +import torch.nn.functional as F +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.forward_context import ForwardContext, get_forward_context + + +@dataclass +class AscendCommonAttentionMetadata: + """ + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. 
+ """ + + query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor + """(batch_size + 1,), the start location of each request in query Tensor""" + + seq_lens_cpu: torch.Tensor + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" + + seq_lens: torch.Tensor + """same to seq_lens_cpu, for compatibility with some new attn metadata + (such as GDN).""" + + num_computed_tokens_cpu: torch.Tensor + """(batch_size,), the number of computed tokens for each request""" + + num_reqs: int + """Number of requests""" + num_actual_tokens: int + """Total number of tokens in batch""" + + max_query_len: int + """Max token number of request in batch""" + + decode_token_per_req: int + """decode token number per request""" + + block_table_tensor: torch.Tensor + + slot_mapping: torch.Tensor + + actual_seq_lengths_q: list[int] + + positions: torch.Tensor = None + + attn_mask: torch.Tensor = None + + spec_attn_mask: torch.Tensor = None + + attn_state: Any = None + + enable_dbo_across_dp: bool = False + + is_only_prefill: bool = False + + graph_pad_size: int = -1 + + # num_input_tokens refers to total number of tokens including + # padding tokens. It is used to handle some padding operations. + num_input_tokens: int = 0 + + # NOTE: This is a temporary solution for rotary embedding in MLA + cos: torch.Tensor = None + sin: torch.Tensor = None + + +def split_decodes_and_prefills( + common_attn_metadata: AscendCommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: AscendCommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. 
+ """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, num_tokens, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, num_tokens, 0 + + first_prefill = is_prefill.int().argmax(dim=-1).item() + num_decodes = first_prefill + num_prefills = num_reqs - num_decodes + num_decode_tokens = query_start_loc[first_prefill].item() + num_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) + + +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + # TODO: assert ascendMetadata + connector.wait_for_layer_load(layer_name) + + +def maybe_save_kv_layer_to_connector( + layer_name: str, + kv_cache_layer: List[torch.Tensor], +): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + # TODO: assert ascendMetadata + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) + + +def round_up(val: int, align: int) -> int: + if align == 0: + return 0 + return -(val // -align) * align + + +def trans_rope_weight(weight, rope_dim): + if rope_dim == 0: + return weight.contiguous() + nope_part = weight[..., :-rope_dim, :] + rope_part = weight[..., -rope_dim:, :] + reordered_rope_part = torch.cat( + (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2) + return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous() + + +def transdata(nd_mat, block_size: tuple = (16, 16)): + r = round_up(nd_mat.shape[0], block_size[0]) + c = round_up(nd_mat.shape[1], block_size[1]) + r_pad = r - nd_mat.shape[0] + c_pad = c - nd_mat.shape[1] + nd_mat = F.pad(nd_mat, (0, r_pad, 0, c_pad)) + nz_mat = torch.permute( + torch.reshape( + nd_mat, + (r // block_size[0], block_size[0], c // block_size[1], + block_size[1]), + ), + [2, 0, 1, 3], + ) + nz_mat = torch.reshape( + nz_mat, + (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) + return nz_mat diff --git a/vllm_npu/compilation/__init__.py b/vllm_npu/compilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/compilation/acl_graph.py b/vllm_npu/compilation/acl_graph.py new file mode 100644 index 0000000..92f5532 --- /dev/null +++ b/vllm_npu/compilation/acl_graph.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from dataclasses import dataclass +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch_npu +import vllm.envs as envs +from vllm.compilation.counter import compilation_counter +from vllm.compilation.cuda_graph import CUDAGraphOptions +from vllm.compilation.monitor import validate_cudagraph_capturing_enabled +from vllm.config import CUDAGraphMode, VllmConfig +from 
vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.logger import logger +from vllm.platforms import current_platform + +from ..utils import weak_ref_tensors + + +@dataclasses.dataclass +class ACLGraphEntry: + batch_descriptor: BatchDescriptor + aclgraph: Optional[torch.npu.NPUGraph] = None + output: Optional[Any] = None + + # for aclgraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +class ACLGraphWrapper: + """Wraps a runnable to add acl graph capturing and replaying ability. And + provide attribute access to the underlying `runnable` via `__getattr__`. + + The workflow of this wrapper in the aclgraph dispatching is as follows: + 1. At initialization, a runtime mode is assigned to the wrapper (FULL or + PIECEWISE). + 2. At runtime, the wrapper receives a runtime_mode and a + batch_descriptor(key) from the forward context and blindly trust them + for aclgraph dispatching. + 3. If runtime_mode is NONE or runtime_mode does not match the mode of the + wrapper, just call the runnable directly. + 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, + the wrapper will perform aclgraph capture(if key does not exist, create + a new entry and cache it) or replay (if key exists in the cache). + + Note: ACLGraphWrapper does not store persistent buffers or copy any + runtime inputs into that buffers for replay. We assume implementing them + is done outside of the wrapper. That is because we do not make any + assumption on the dynamic shape (batch size) of the runtime inputs, as a + trade-off for staying orthogonal to compilation logic. Nevertheless, + tracing and checking the input addresses to be consistent during replay is + guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". + """ + + def __init__(self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + graph_pool: Any = None, + cudagraph_options: Optional[CUDAGraphOptions] = None): + self.runnable = runnable + self.vllm_config = vllm_config + self.graph_pool = graph_pool + self.runtime_mode = runtime_mode + self.compilation_config = vllm_config.compilation_config + + self.first_run_finished = False + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # assert runtime_mode is not NONE(no aclgraph), otherwise, we don't + # need to initialize a ACLGraphWrapper. + assert self.runtime_mode != CUDAGraphMode.NONE + if self.graph_pool is None: + self.graph_pool = current_platform.get_global_graph_pool() + + if cudagraph_options is None: + cudagraph_options = CUDAGraphOptions() + self.aclgraph_options = cudagraph_options + # the entries for different batch descriptors that we need to capture + # aclgraphs for. + self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\ + = {} + + def __getattr__(self, key: str): + # allow accessing the attributes of the runnable. + if hasattr(self.runnable, key): + return getattr(self.runnable, key) + raise AttributeError(f"Attribute {key} not exists in the runnable of " + f"aclgraph wrapper: {self.runnable}") + + def unwrap(self) -> Callable: + # in case we need to access the original runnable. 
+ return self.runnable + + def __call__(self, *args, **kwargs): + forward_context = get_forward_context() + batch_descriptor = forward_context.batch_descriptor + aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode + + if aclgraph_runtime_mode == CUDAGraphMode.NONE or \ + aclgraph_runtime_mode != self.runtime_mode: + # CUDAGraphMode.NONE could mean the profile run, a warmup run, or + # running without aclgraphs. + # We do not trigger capture/replay if the runtime mode is not + # matches. This enables properly dispatching to the correct + # CUDAGraphWrapper when nesting multiple instances with different + # runtime modes. + return self.runnable(*args, **kwargs) + + if batch_descriptor not in self.concrete_aclgraph_entries: + # create a new entry for this batch descriptor + self.concrete_aclgraph_entries[batch_descriptor] = \ + ACLGraphEntry(batch_descriptor=batch_descriptor) + + entry = self.concrete_aclgraph_entries[batch_descriptor] + + if entry.aclgraph is None: + if self.aclgraph_options.debug_log_enable: + # Since we capture aclgraph for many different shapes and + # capturing is fast, we don't need to log it for every + # shape. E.g. we only log it for the first subgraph in + # piecewise mode. + logger.debug("Capturing a aclgraph on (%s,%s)", + self.runtime_mode.name, entry.batch_descriptor) + # validate that aclgraph capturing is legal at this point. + validate_cudagraph_capturing_enabled() + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + aclgraph = torch.npu.NPUGraph() + + with ExitStack() as stack: + if self.aclgraph_options.gc_disable: + # during every model forward for piecewise aclgraph + # mode, we will capture many pieces of aclgraphs + # (roughly one per layer). running gc again and again + # across layers will make the aclgraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.npu.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + forward_context.capturing = True + with torch.npu.graph(aclgraph, pool=self.graph_pool): + # `output` is managed by pytorch's aclgraph pool + output = self.runnable(*args, **kwargs) + if self.aclgraph_options.weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph in piecewise aclgraph mode, because + # the output of the last graph will not be used by + # any other acl graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.aclgraph = aclgraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during acl graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + f"Input addresses for aclgraphs are different " + f"during replay. 
Expected {entry.input_addresses}, " + f"got {new_input_addresses}") + + logger.info_once("Replaying aclgraph") + entry.aclgraph.replay() + return entry.output + + +def update_attn_params(update_stream, + forward_context, + runtime_shape, + kv_transfer_config=None): + graph_params = get_graph_params() + + # NOTE(Angazenn): By moving the npu-stream context ahead, + # (see https://github.com/vllm-project/vllm-ascend/pull/3985) + # we can reduce host overhead introduced by stream initialization. + # However, we find that this might cause potential accuracy problems + # with pd-disaggreagation. Therefore, this optimization is only enabled + # without pd-disaggreagation. We are working on to solve this problem + # directly int the future. + if kv_transfer_config is not None: + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + output, + ) = param + seq_lens = forward_context.attn_metadata[key].seq_lens + + # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY + # mode with GQA. This is triggered by getting workspace for _npu_paged_attention + # in torch_npu. On some cases, _npu_paged_attention requires different workspace + # among various seq_lens. So additional get_workspace is added here + # to avoid such bugs. + # TODO(Angazenn): we will remove this once _npu_paged_attention is fully + # replaced by npu_fused_infer_attention_score which does not contain such bugs. + workspace = torch_npu._npu_paged_attention_get_workspace( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) + + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + else: + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + output, + ) = param + seq_lens = forward_context.attn_metadata[key].seq_lens + + workspace = torch_npu._npu_paged_attention_get_workspace( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +def update_mla_attn_params(update_stream, forward_context, runtime_shape, + 
speculative_config): + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, + spec_attn_mask, sparse_mode, scale, block_table, block_size, + seq_lens_list, actual_seq_lengths, attn_output, + softmax_lse) = param + seq_lens_list = forward_context.attn_metadata[ + key].decode.seq_lens_list + if speculative_config and speculative_config.method == "deepseek_mtp": + actual_seq_lengths = forward_context.attn_metadata[ + key].decode.actual_seq_lengths_q + spec_multiple = speculative_config.num_speculative_tokens + 1 + seq_lens_list = seq_lens_list + [0] * ( + runtime_shape // spec_multiple - len(seq_lens_list)) + actual_seq_lengths = [ + spec_multiple * (i + 1) + for i in range(runtime_shape // spec_multiple) + ] + else: + seq_lens_list = seq_lens_list + [0] * (runtime_shape - + len(seq_lens_list)) + + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=seq_lens_list, + actual_seq_lengths=actual_seq_lengths, + workspace=graph_params.workspaces.get(runtime_shape), + out=[attn_output, softmax_lse]) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +@dataclass +class GraphParams: + events: dict[int, list[torch.npu.ExternalEvent]] + workspaces: dict[int, torch.Tensor] + handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]] + attn_params: dict[int, list[tuple]] + + +_graph_params: Optional[GraphParams] = None + + +def set_graph_params(aclgraph_capture_sizes: set[int]): + global _graph_params + if _graph_params is not None: + raise ValueError("Graph parameters have already been set!") + _graph_params = GraphParams( + {size: [] + for size in aclgraph_capture_sizes}, + {size: None + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + ) + + +def update_graph_params_workspaces(num_tokens: int, workspace: Any): + global _graph_params + if _graph_params is not None: + _graph_params.workspaces[num_tokens] = workspace + + +def get_graph_params(): + return _graph_params diff --git a/vllm_npu/core/__init__.py b/vllm_npu/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/core/recompute_schedule_config.py b/vllm_npu/core/recompute_schedule_config.py new file mode 100644 index 0000000..0bad7df --- /dev/null +++ b/vllm_npu/core/recompute_schedule_config.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from dataclasses import dataclass, fields +from typing import Type, Union + +from vllm.config import SchedulerConfig + +MAX_INT = 2147483647 + + +@dataclass +class RecomputeSchedulerConfig(SchedulerConfig): + scheduler_cls: Union[str, Type[object]] = ( + "vllm_npu.core.recompute_scheduler.RecomputeScheduler") + + @classmethod + def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig): + scheduler_config = { + field.name: getattr(vllm_scheduler_config, field.name) + for field in fields(vllm_scheduler_config) if field.init + } + scheduler_config["scheduler_cls"] = ( + "vllm_npu.core.recompute_scheduler.RecomputeScheduler") + return cls(**scheduler_config) diff --git a/vllm_npu/core/recompute_scheduler.py b/vllm_npu/core/recompute_scheduler.py new file mode 100644 index 0000000..1104273 --- /dev/null +++ b/vllm_npu/core/recompute_scheduler.py @@ -0,0 +1,1392 @@ +## +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from __future__ import annotations + +import itertools +import time +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +import numpy.typing as npt +from vllm.config import VllmConfig +from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch +from vllm.distributed.kv_transfer.kv_connector.factory import \ + KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ + KVConnectorStats +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, + compute_encoder_budget) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData +from vllm.v1.core.sched.request_queue import (SchedulingPolicy, + create_request_queue) +from vllm.v1.core.sched.utils import check_stop, remove_all +from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, + EngineCoreOutputs, FinishReason) +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import ConstantList + +logger = init_logger(__name__) + + +class RecomputeScheduler(SchedulerInterface): + """This Scheduler extends vllm's original v1 scheduler of version 0.11 + to fix recomputing bug.""" + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.kv_cache_config = kv_cache_config + self.kv_events_config = vllm_config.kv_events_config + self.parallel_config = vllm_config.parallel_config + self.log_stats = log_stats + self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder + + # include_finished_set controls whether a separate set of finished + # request ids should be included in the EngineCoreOutputs returned + # by update_from_outputs(). This is currently used in the multi-engine + # case to track request lifetimes efficiently. + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + self.enable_kv_cache_events = ( + self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events) + + # Create KVConnector for the Scheduler. 
Note that each Worker + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None + if self.vllm_config.kv_transfer_config is not None: + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "Multiple KV cache groups are not currently supported " + "with KV connectors") + assert not self.is_encoder_decoder, ( + "Encoder-decoder models are not currently supported " + "with KV connectors") + self.connector = KVConnectorFactory.create_connector( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + self.kv_event_publisher = EventPublisherFactory.create( + self.kv_events_config, + self.parallel_config.data_parallel_rank, + ) + + num_gpu_blocks = self.cache_config.num_gpu_blocks + assert num_gpu_blocks is not None and num_gpu_blocks > 0 + + self.block_size = self.cache_config.block_size + + self.dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): The scheduler’s block_size must be multiplied + # by dcp_world_size, since block hashes are computed on the + # original full token sequence at a granularity of + # original_block_size × dcp_world_size. + if self.dcp_world_size > 1: + self.block_size *= self.dcp_world_size + + # req_id -> Request + self.requests: dict[str, Request] = {} + # Scheduling policy + if self.scheduler_config.policy == "priority": + self.policy = SchedulingPolicy.PRIORITY + elif self.scheduler_config.policy == "fcfs": + self.policy = SchedulingPolicy.FCFS + else: + raise ValueError( + f"Unknown scheduling policy: {self.scheduler_config.policy}") + # Priority queues for requests. + self.waiting = create_request_queue(self.policy) + self.running: list[Request] = [] + + # The request IDs that are finished in between the previous and the + # current steps. This is used to notify the workers about the finished + # requests so that they can free the cached states for those requests. + # This is flushed at the end of each scheduling step. + self.finished_req_ids: set[str] = set() + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() + + # Encoder-related. + # Calculate encoder cache size if applicable + # NOTE: For now we use the same budget for both compute and space. + # This can be changed when we make encoder cache for embedding caching + # across requests. + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + mm_registry=mm_registry, + ) + + # NOTE(woosuk): Here, "encoder" includes the vision encoder (and + # projector if needed) for MM models as well as encoder-decoder + # transformers. + self.max_num_encoder_input_tokens = encoder_compute_budget + # NOTE: For the models without encoder (e.g., text-only models), + # the encoder cache will not be initialized because cache size is 0 + # for these models. + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size) + + speculative_config = vllm_config.speculative_config + self.use_eagle = False + self.num_spec_tokens = self.num_lookahead_tokens = 0 + if speculative_config: + self.num_spec_tokens = speculative_config.num_speculative_tokens + if speculative_config.use_eagle(): + self.use_eagle = True + self.num_lookahead_tokens = self.num_spec_tokens + + # Create the KV cache manager. 
+ self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, + enable_caching=self.cache_config.enable_prefix_caching, + use_eagle=self.use_eagle, + log_stats=self.log_stats, + enable_kv_cache_events=self.enable_kv_cache_events, + dcp_world_size=self.dcp_world_size, + ) + self.use_pp = self.parallel_config.pipeline_parallel_size > 1 + + def schedule(self) -> RecomputeSchedulerOutput: + """This scheduler extends vLLM's original v1 scheduler + by introducing a decoding instance recomputing scheduling strategy. + Specifically, if a request is preempted in the decoding instance, + it halts the process with the recomputed symbol and recalculates + its KVC in the prefill instance.""" + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + recomputed_reqs: list[RecomputeReqInfo] = [] + + req_to_new_blocks: dict[str, KVCacheBlocks] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_compute_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + + num_new_tokens = (request.num_tokens_with_spec + + request.num_output_placeholders - + request.num_computed_tokens) + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - 1 - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_compute_budget = encoder_compute_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_compute_budget) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when + # (1) PP>1 and we have already scheduled all prompt tokens + # but they are not finished yet. + # (2) Async scheduling and the request has reached to either + # its max_total_tokens or max_model_len. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. 
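+                # Leave this request in the running queue untouched and move
+                # on to the next one; no token budget is consumed for it.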
+ req_index += 1 + continue + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) + if new_blocks is None: + transfer_config = self.vllm_config.kv_transfer_config + if transfer_config is not None and not transfer_config.is_kv_producer: + recomputed_req = self.running.pop() + self.kv_cache_manager.free(recomputed_req) + recomputed_reqs.append( + RecomputeReqInfo(recomputed_req.request_id, + recomputed_req.output_token_ids, + recomputed_req.client_index)) + if recomputed_req == request: + can_schedule = False + break + else: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, + scheduled_timestamp) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. + if not preempted_reqs and not recomputed_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting.peek_request() + + # KVTransfer: skip request if still waiting for remote kvs. 
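+                # Such requests were parked in WAITING_FOR_REMOTE_KVS when
+                # load_kv_async was set for them in an earlier step; they flip
+                # back to WAITING once _update_waiting_for_remote_kv sees the
+                # transfer in finished_recving_kv_req_ids.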
+ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if (self.lora_config and request.lora_request and + (len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras)): + # Scheduling would exceed max_loras, skip. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + if num_external_computed_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_compute_budget = encoder_compute_budget + + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + # Number of tokens to be scheduled. + else: + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + + # chunked prefill has to be enabled explicitly to allow + # pooling requests to be chunked + if not self.scheduler_config.chunked_prefill_enabled and \ + num_new_tokens > token_budget: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. 
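+                # _try_schedule_encoder_inputs may shrink num_new_tokens so
+                # that this step stops just before an encoder input that does
+                # not fit the encoder compute budget or the encoder cache.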
+ if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_compute_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = (0 if request.num_computed_tokens + == 0 else + self.num_lookahead_tokens) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + num_encoder_tokens = \ + self.scheduler_config.max_num_encoder_input_tokens + else: + num_encoder_tokens = 0 + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=effective_lookahead_tokens, + delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, + ) + + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + + # Request was already popped from self.waiting + # unless it was re-added above due to new_blocks being None. + request = self.waiting.pop_request() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.prepend_request(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id)) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. 
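+                    # Reserve encoder-cache space for each scheduled input and
+                    # carry the reduced compute budget over to the remaining
+                    # waiting requests.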
+ for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + + scheduled_resumed_reqs) + structured_output_request_ids, grammar_bitmask = ( + self.get_grammar_bitmask(scheduled_requests, + scheduled_spec_decode_tokens)) + scheduler_output = RecomputeSchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + recomputed_reqs=recomputed_reqs, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. 
Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + self._update_after_schedule(scheduler_output) + return scheduler_output + + def _update_after_schedule( + self, + scheduler_output: RecomputeSchedulerOutput, + ) -> None: + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + request = self.requests[req_id] + request.num_computed_tokens += num_scheduled_token + + # NOTE: _free_encoder_inputs relies on num_computed_tokens, which + # may be updated again in _update_from_output for speculative + # decoding. However, it is safe to call the method here because + # encoder inputs are always part of the prompt, not the output, + # and thus are unaffected by speculative decoding. + if request.has_encoder_inputs: + self._free_encoder_inputs(request) + + # Clear the finished request IDs. + # NOTE: We shouldn't do self.finished_req_ids.clear() here because + # it will also affect the scheduler output. + self.finished_req_ids = set() + + def _make_cached_request_data( + self, + running_reqs: list[Request], + resumed_reqs: list[Request], + num_scheduled_tokens: dict[str, int], + spec_decode_tokens: dict[str, list[int]], + req_to_new_blocks: dict[str, KVCacheBlocks], + ) -> CachedRequestData: + req_ids: list[str] = [] + new_token_ids: list[list[int]] = [] + new_block_ids: list[Optional[tuple[list[int], ...]]] = [] + num_computed_tokens: list[int] = [] + + use_connector = self.connector is not None + for req in itertools.chain(running_reqs, resumed_reqs): + req_id = req.request_id + req_ids.append(req_id) + num_tokens = (num_scheduled_tokens[req_id] - + len(spec_decode_tokens.get(req_id, ()))) + if self.use_pp: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. Otherwise, we don't + # need to send the sampled tokens back because the model runner + # will cache them. + token_ids = req.all_token_ids[req.num_computed_tokens:req. + num_computed_tokens + num_tokens] + new_token_ids.append(token_ids) + elif use_connector: + # When using a KVConnector, we add a placeholder to avoid index + # out of bounds errors. TODO: Remove this once the KVConnector + # is updated to handle token IDs properly. 
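+                # The empty placeholder keeps new_token_ids index-aligned with
+                # req_ids when a connector is in use, without actually
+                # resending sampled tokens.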
+ new_token_ids.append([]) + new_block_ids.append( + req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + num_computed_tokens.append(req.num_computed_tokens) + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) + resumed_from_preemption += [True] * len(resumed_reqs) + + return CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + def _try_schedule_encoder_inputs( + self, + request: Request, + num_computed_tokens: int, + num_new_tokens: int, + encoder_compute_budget: int, + ) -> tuple[list[int], int, int]: + """ + Determine which encoder inputs need to be scheduled in the current step, + and update `num_new_tokens` and encoder token budget accordingly. + + An encoder input will be scheduled if: + - Its output tokens overlap with the range of tokens being computed + in this step, i.e., + [num_computed_tokens, num_computed_tokens + num_new_tokens). + - It is not already computed and stored in the encoder cache. + - There is sufficient encoder token budget to process it. + - The encoder cache has space to store it. + + If an encoder input cannot be scheduled due to cache or budget + limitations, the method adjusts `num_new_tokens` to schedule only the + decoder tokens up to just before the unschedulable encoder input. + + Note that num_computed_tokens includes both locally cached + blocks and externally cached blocks (via KVConnector). + """ + if num_new_tokens == 0 or not request.has_encoder_inputs: + return [], num_new_tokens, encoder_compute_budget + encoder_inputs_to_schedule: list[int] = [] + mm_features = request.mm_features + assert mm_features is not None + assert len(mm_features) > 0 + + # NOTE: since scheduler operates on the request level (possibly with + # multiple encoder inputs per request), we need to create temporary + # trackers for accounting at the encoder input level. + mm_hashes_to_schedule = set() + num_tokens_to_schedule = 0 + for i, mm_feature in enumerate(mm_features): + start_pos = mm_feature.mm_position.offset + num_encoder_tokens = mm_feature.mm_position.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, num_computed_tokens + num_new_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_new_tokens: + # The encoder input is not needed in this step. + break + + if self.is_encoder_decoder and num_computed_tokens > 0: + assert start_pos == 0, ( + "Encoder input should be processed at the beginning of " + "the sequence when encoder-decoder models are used.") + # Encoder input has already been computed + # The calculation here is a bit different. We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder input is already computed and stored + # in the decoder's KV cache. 
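+                # (its whole [start_pos, start_pos + num_encoder_tokens) range
+                # lies inside the already-computed prefix, so nothing has to be
+                # re-encoded)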
+ continue + + if not self.is_encoder_decoder: + # We are not using the encoder cache for encoder-decoder models, + # yet. + if request.mm_features[i].identifier in mm_hashes_to_schedule: + # The same encoder input has already been scheduled in the + # current step. + continue + + if self.encoder_cache_manager.check_and_update_cache( + request, i): + # The encoder input is already computed and cached from a + # previous step. + continue + + # If no encoder input chunking is allowed, we do not want to + # partially schedule a multimodal item. If the scheduled range would + # only cover part of the mm input, roll back to before the mm item. + if (self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens)): + num_new_tokens = start_pos - num_computed_tokens + break + + if not self.encoder_cache_manager.can_allocate( + request, i, encoder_compute_budget, + num_tokens_to_schedule): + # The encoder cache is full or the encoder budget is exhausted. + # NOTE(woosuk): We assume that the encoder input tokens should + # be processed altogether, as the encoder usually uses + # bidirectional attention. + if num_computed_tokens < start_pos: + # We only schedule the decoder tokens just before the + # encoder input. + num_new_tokens = start_pos - num_computed_tokens + else: + # Because of prefix caching, num_computed_tokens is greater + # than start_pos even though its encoder input is not + # available. In this case, we can't schedule any token for + # the request in this step. + num_new_tokens = 0 + break + + num_tokens_to_schedule += num_encoder_tokens + encoder_compute_budget -= num_encoder_tokens + mm_hashes_to_schedule.add(request.mm_features[i].identifier) + encoder_inputs_to_schedule.append(i) + + return ( + encoder_inputs_to_schedule, + num_new_tokens, + encoder_compute_budget, + ) + + def get_grammar_bitmask( + self, + requests: list[Request], + scheduled_spec_decode_tokens: dict[str, list[int]], + ): + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to its index in the batch. + # This will help us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + for i, req in enumerate(requests): + if req.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. 
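+                # The batch index recorded here is later used to slice the
+                # grammar bitmask so that the mask is only applied to
+                # structured-output requests.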
+ structured_output_request_ids[req.request_id] = i + + if not structured_output_request_ids: + bitmask = None + else: + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + return structured_output_request_ids, bitmask + + def update_from_output( + self, + scheduler_output: RecomputeSchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output + + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: Optional[SpecDecodingStats] = None + kv_connector_stats = (kv_connector_output.kv_connector_stats + if kv_connector_output else None) + # return recomputed requests as EngineCoreOutput + for req_info in scheduler_output.recomputed_reqs: + outputs[req_info.client_index].append( + EngineCoreOutput( + request_id=req_info.request_id, + finish_reason=FinishReason.STOP, + new_token_ids=[req_info.output_token_ids[-1]], + stop_reason="recomputed", + )) + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[ + req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens. + request.num_computed_tokens -= num_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + status_before_stop = request.status + + # Check for stop and update request status. + if new_token_ids: + new_token_ids, stopped = self._update_request_with_output( + request, new_token_ids) + + # Stop checking for pooler models. 
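+            # Pooling requests produce no sampled tokens, so the stop check is
+            # driven by the pooler output instead.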
+ pooler_output = None + if pooler_outputs: + pooler_output = pooler_outputs[req_index] + stopped = check_stop(request, self.max_model_len, + pooler_output) + + if stopped: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if request.sampling_params is not None \ + and request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and self.structured_output_manager.should_advance( + request): + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # checked above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ + or kv_transfer_params: + + # Add EngineCoreOutput for this Request. + outputs[request.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + )) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = remove_all(self.running, stopped_running_reqs) + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # KV Connector: update state for finished KV Transfers. + if model_runner_output.kv_connector_output: + self._update_from_kv_xfer_finished( + model_runner_output.kv_connector_output) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if (stats := self.make_stats(spec_decoding_stats, + kv_connector_stats)) is not None: + # Return stats to only one of the front-ends. + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. 
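+                # Attach them to client index 0 via an otherwise-empty
+                # EngineCoreOutputs.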
+ engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats + + return engine_core_outputs + + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + stopped = False + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + del new_token_ids[num_new:] # Trim new tokens if needed. + break + return new_token_ids, stopped + + def _free_encoder_inputs(self, request: Request) -> None: + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if not cached_encoder_input_ids: + return + + # Here, we use list(set) to avoid modifying the set while iterating + # over it. + for input_id in list(cached_encoder_input_ids): + mm_feature = request.mm_features[input_id] + start_pos = mm_feature.mm_position.offset + num_tokens = mm_feature.mm_position.length + if self.is_encoder_decoder and request.num_computed_tokens > 0: + # With Whisper, as soon as we've generated a single token, + # we know we're done with the encoder input. Cross Attention + # KVs have been calculated and cached already. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + elif start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + def update_draft_token_ids( + self, + draft_token_ids: DraftTokenIds, + ) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if not spec_token_ids: + # NOTE(woosuk): request.spec_token_ids should be updated. + request.spec_token_ids.clear() + elif self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids) + else: + request.spec_token_ids = spec_token_ids + + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + + def add_request(self, request: Request) -> None: + self.waiting.add_request(request) + self.requests[request.request_id] = request + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. 
+ """ + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids, ) + else: + request_ids = set(request_ids) + + running_requests_to_remove = set() + waiting_requests_to_remove = [] + valid_requests = [] + + # First pass: collect requests to remove from queues + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + + valid_requests.append(request) + if request.status == RequestStatus.RUNNING: + running_requests_to_remove.add(request) + else: + waiting_requests_to_remove.append(request) + + # Remove all requests from queues at once for better efficiency + if running_requests_to_remove: + self.running = remove_all(self.running, running_requests_to_remove) + if waiting_requests_to_remove: + self.waiting.remove_requests(waiting_requests_to_remove) + + # Second pass: set status and free requests + for request in valid_requests: + request.status = finished_status + self._free_request(request) + + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + self.encoder_cache_manager.free(request) + request_id = request.request_id + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) + + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + assert request.is_finished() + self.kv_cache_manager.free(request) + del self.requests[request.request_id] + + def get_num_unfinished_requests(self) -> int: + return len(self.waiting) + len(self.running) + + def has_finished_requests(self) -> bool: + return len(self.finished_req_ids) > 0 + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() + + def make_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats] = None, + kv_connector_stats: Optional[KVConnectorStats] = None, + ) -> Optional[SchedulerStats]: + if not self.log_stats: + return None + prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() + assert prefix_cache_stats is not None + return SchedulerStats(num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted + for req in self.running), + kv_connector_stats=kv_connector_stats.data + if kv_connector_stats else None) + + def make_spec_decoding_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats], + num_draft_tokens: int, + num_accepted_tokens: int, + ) -> Optional[SpecDecodingStats]: + if not self.log_stats: + return None + if spec_decoding_stats is None: + spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) + spec_decoding_stats.observe_draft( + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) + return spec_decoding_stats + + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() + + ######################################################################## + # KV Connector Related Methods + ######################################################################## + + def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + return 
self.connector + + def _connector_finished( + self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Invoke the KV connector request_finished() method if applicable. + + Returns optional kv transfer parameters to be included with the + request outputs. + """ + if self.connector is None: + return False, None + + (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + return self.connector.request_finished(request, block_ids) + + def _update_waiting_for_remote_kv(self, request: Request) -> bool: + """ + KV Connector: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. + + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + assert self.connector is not None + if request.request_id not in self.finished_recving_kv_req_ids: + return False + + # Now that the blocks are ready, actually cache them. + (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + num_computed_tokens = len(block_ids) * self.block_size + # Handle the case where num request tokens less than one block. + num_computed_tokens = min(num_computed_tokens, request.num_tokens) + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + # This will cache the blocks iff caching is enabled. + self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) + return True + + def _update_from_kv_xfer_finished(self, + kv_connector_output: KVConnectorOutput): + """ + KV Connector: update the scheduler state based on the output. + + The Worker side connectors add finished_recving and + finished_sending reqs to the output. + * if finished_sending: free the blocks + # if finished_recving: add to state so we can + schedule the request during the next step. + """ + + if self.connector is not None: + self.connector.update_connector_output(kv_connector_output) + + # KV Connector:: update recv and send status from last step. + for req_id in (kv_connector_output.finished_recving or ()): + logger.debug("Finished recving KV transfer for request %s", req_id) + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (kv_connector_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) + if req_id not in self.requests: + logger.warning( + "Got finished sending KV transfer for request %s," + "but the request is already freed.", req_id) + else: + self._free_blocks(self.requests[req_id]) + + +@dataclass +class RecomputeReqInfo: + request_id: str + output_token_ids: ConstantList + client_index: int = 0 + + +@dataclass +class RecomputeSchedulerOutput: + + # list of the requests that are scheduled for the first time. + # We cache the request's data in each worker process, so that we don't + # need to re-send it every scheduling step. + scheduled_new_reqs: list[NewRequestData] + # list of the requests that have been scheduled before. + # Since the request's data is already cached in the worker processes, + # we only send the diff to minimize the communication cost. + scheduled_cached_reqs: CachedRequestData + + # req_id -> num_scheduled_tokens + # Number of tokens scheduled for each request. 
+ num_scheduled_tokens: dict[str, int] + # Total number of tokens scheduled for all requests. + # Equal to sum(num_scheduled_tokens.values()) + total_num_scheduled_tokens: int + # req_id -> spec_token_ids + # If a request does not have any spec decode tokens, it will not be + # included in the dictionary. + scheduled_spec_decode_tokens: dict[str, list[int]] + # req_id -> encoder input indices that need processing. + # E.g., if a request has [0, 1], it could mean the vision encoder needs + # to process that the request's 0-th and 1-th images in the current step. + scheduled_encoder_inputs: dict[str, list[int]] + # Number of common prefix blocks for all requests in each KV cache group. + # This can be used for cascade attention. + num_common_prefix_blocks: list[int] + + # Request IDs that are finished in between the previous and the current + # steps. This is used to notify the workers about the finished requests + # so that they can free the cached states for those requests. + finished_req_ids: set[str] + # list of mm_hash strings associated with the encoder outputs to be + # freed from the encoder cache. + free_encoder_mm_hashes: list[str] + + # Dict of request ids to their index within the batch + # for filling the next token bitmask + structured_output_request_ids: dict[str, int] + # the bitmask for the whole batch + grammar_bitmask: Optional[npt.NDArray[np.int32]] + + # requests that need to recompute kv + recomputed_reqs: list[RecomputeReqInfo] + + # KV Cache Connector metadata. + kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm_npu/core/schedule_config.py b/vllm_npu/core/schedule_config.py new file mode 100644 index 0000000..1fad95d --- /dev/null +++ b/vllm_npu/core/schedule_config.py @@ -0,0 +1,108 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +from dataclasses import dataclass, fields +from typing import Type, Union + +from vllm.config import SchedulerConfig + +MAX_INT = 2147483647 + + +@dataclass +class AscendSchedulerConfig(SchedulerConfig): + enable_chunked_prefill: bool = False + max_long_partial_prefills: int = 1 + long_prefill_token_threshold: int = MAX_INT + policy: str = "fcfs" + scheduler_cls: Union[str, Type[object]] = ( + "vllm_npu.core.scheduler.AscendScheduler") + enable_pd_transfer: bool = False + decode_max_num_seqs: int = 0 + + @classmethod + def initialize_from_config( + cls, + vllm_scheduler_config: SchedulerConfig, + ascend_scheduler_config, + ): + scheduler_config = { + field.name: getattr(vllm_scheduler_config, field.name) + for field in fields(vllm_scheduler_config) if field.init + } + # Override default values into original SchedulerConfig + scheduler_config["enable_chunked_prefill"] = False + scheduler_config["max_long_partial_prefills"] = None + scheduler_config["long_prefill_token_threshold"] = None + scheduler_config["policy"] = "fcfs" + scheduler_config["scheduler_cls"] = ( + "vllm_npu.core.scheduler.AscendScheduler") + scheduler_config["enable_pd_transfer"] = False + scheduler_config["decode_max_num_seqs"] = 0 + # Override params in original SchedulerConfig with params in ascend_scheduler_config + for k, _ in scheduler_config.items(): + if hasattr(ascend_scheduler_config, k): + scheduler_config[k] = getattr(ascend_scheduler_config, k) + return cls(**scheduler_config) + + def __post_init__(self) -> None: + self.max_num_encoder_input_tokens = self.max_num_batched_tokens + self.encoder_cache_size = self.max_num_batched_tokens + self.chunked_prefill_enabled = self.enable_chunked_prefill + if (self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled): + raise ValueError( + "Ascend scheduler is enabled without chunked prefill feature. " + f"Argument max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len.") + # concurrent partial prefills. Default is 1 meaning not enabled. + if self.max_long_partial_prefills is None: + self.max_long_partial_prefills = 1 + self.long_prefill_token_threshold = MAX_INT + + if self.long_prefill_token_threshold is None or \ + self.long_prefill_token_threshold <= 0: + if self.max_model_len is None: + self.long_prefill_token_threshold = MAX_INT + else: + self.long_prefill_token_threshold = \ + max(1, int(self.max_model_len * 0.04)) + + if self.max_long_partial_prefills < 0: + raise ValueError( + f"max_long_partial_prefills must be non-negative, but got " + f"{self.max_long_partial_prefills}") + if self.long_prefill_token_threshold < 0: + raise ValueError( + f"long_prefill_token_threshold must be non-negative, but got " + f"{self.long_prefill_token_threshold}") + + if self.policy != "fcfs": + raise NotImplementedError( + f"currently AscendScheduler only supports fcfs policy, got {self.policy}" + ) + if self.send_delta_data: + raise NotImplementedError( + "currently AscendScheduler doesn't support send_delta_data.") + if getattr(self, "scheduler_delay_factor", 0) > 0: + raise NotImplementedError( + "currently AscendScheduler doesn't support scheduler_delay_factor." 
+ ) diff --git a/vllm_npu/core/scheduler.py b/vllm_npu/core/scheduler.py new file mode 100644 index 0000000..f4c8cc7 --- /dev/null +++ b/vllm_npu/core/scheduler.py @@ -0,0 +1,587 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +import time +from collections import deque +from typing import Iterable, Union + +from vllm.config import VllmConfig +from vllm.distributed.kv_events import KVEventBatch +from vllm.logger import logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.utils import cdiv +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager + + +class AscendScheduler(Scheduler): + """This Scheduler extends vllm's original v1 scheduler + with prefill-first scheduling strategy.""" + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + super().__init__(vllm_config, kv_cache_config, + structured_output_manager, mm_registry, + include_finished_set, log_stats) + self.scheduled_req_ids: set[str] = set() + self.running: list[Request] = [] + + self.finished_prefill_reqs: deque[Request] = deque() + enable_pd_transfer = getattr(self.scheduler_config, + 'enable_pd_transfer', False) + decode_max_num_seqs = getattr(self.scheduler_config, + 'decode_max_num_seqs', 0) + self.phase = "" if not enable_pd_transfer else "prefill" + self.decode_max_num_running_reqs = max(self.max_num_running_reqs, + decode_max_num_seqs) + + def schedule(self) -> SchedulerOutput: + if self.scheduler_config.chunked_prefill_enabled: + return super().schedule() + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + req_to_new_blocks: dict[str, KVCacheBlocks] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + # Record scheduled LoRA requests. 
+ scheduled_loras: set[int] = set() + + # Use a temporary deque to collect requests that need to be skipped + # and put back at the head of the waiting queue later + skipped_waiting_requests: deque[Request] = deque() + + if self.phase == "prefill": + remaining_running_reqs = [] + for request in self.running: + # move request has finished prefill to finished_prefill_reqs + if request.num_tokens > request.num_prompt_tokens: + self.finished_prefill_reqs.append(request) + else: + remaining_running_reqs.append(request) + self.running = remaining_running_reqs + # all request prefilled, change phase to decode + if not self.waiting and not self.running: + self.phase = "decode" + # Skip long prompt requests in prefill stage. + # long_prefill_budget is float('inf') if not use. + if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0: + long_prefill_budget = float('inf') + long_prefill_token_threshold = float('inf') + else: + long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills + long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold + + # Schedule prefill requests first. + while self.waiting and token_budget > 0: + if len(self.running) == (self.decode_max_num_running_reqs + if self.phase == "decode" else + self.max_num_running_reqs): + + break + + request = self.waiting[0] + + def skip_cur_request(): + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + + # P/D: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + skip_cur_request() + continue + + # Check that adding the request still respects the max_loras + # constraint. + if (self.lora_config and request.lora_request and + (len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras)): + # Scheduling would exceed max_loras, skip. + skip_cur_request() + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + else: + # P/D: skip checking prefix cache if loaded from remote kvs. + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + + # P/D: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + blocks = None + # Number of tokens to be scheduled. + else: + prompt_limit = self._get_prompt_limit(request) + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. 
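+                # Without chunked prefill the remaining prompt cannot be split
+                # across steps; prompts exceeding prompt_limit (already capped
+                # by the KV cache capacity) are finished as FINISHED_IGNORED
+                # below.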
+ num_new_tokens = request.num_tokens - num_computed_tokens + max_tokens_in_kvcache = (self.kv_cache_config.num_blocks * + self.block_size) + prompt_limit = min(prompt_limit, max_tokens_in_kvcache) + + # Finish request that exceeds prompt_limit or kv cache size. + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) + request.status = RequestStatus.FINISHED_IGNORED + self.finished_req_ids.add( # type: ignore + request.request_id) # type: ignore + self.waiting.popleft() + continue + + if num_new_tokens > token_budget: + # Scheduling would exceed token_budget, skip. + skip_cur_request() + continue + assert num_new_tokens > 0 + blocks = new_computed_blocks.blocks[0] + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0 or len( + encoder_inputs_to_schedule) == 0: + # The request cannot be scheduled. + break + + watermark = getattr(self.scheduler_config, "watermark", 0.01) + if not self._check_watermark_for_prefill(request, num_new_tokens, + blocks, watermark): + # Scheduling would exceed watermark, skip. + skip_cur_request() + continue + + if num_new_tokens > long_prefill_token_threshold \ + and long_prefill_budget <= 0: + skip_cur_request() + continue + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks=new_computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async) + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + + self.waiting.popleft() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.appendleft(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + self.scheduled_req_ids.add(request.request_id) + # Check request status. + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + + req_to_new_blocks[ + request.request_id] = self.kv_cache_manager.get_blocks( + request.request_id) + # Update request info. + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + if num_new_tokens > long_prefill_token_threshold: + long_prefill_budget -= 1 + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + + # Encoder-related. 
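+            # Allocate encoder-cache entries for the scheduled inputs and carry
+            # the reduced budget forward.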
+ if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + if self.phase == "decode": + while len( + self.running + ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs: + request = self.finished_prefill_reqs.popleft() + self.running.append(request) + + # If no prefill requests are scheduled, + # Schedule decode requests next. + if len(self.scheduled_req_ids) == 0: + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + if request.request_id in self.scheduled_req_ids: + # This request has already been scheduled. + req_index += 1 + continue + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) + assert (request.num_tokens - request.num_computed_tokens) == 1 + num_new_tokens = min(num_new_tokens, token_budget) + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_budget) + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id + not in scheduled_loras): + # Scheduling would exceed max_loras, skip. + num_new_tokens = 0 + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reason: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. Adding the request exceeds the max_loras constraint. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, + scheduled_timestamp) + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. 
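+                # Record the decode request's new blocks and token count;
+                # spec_token_ids is then trimmed to the portion that actually
+                # fits into this step.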
+ scheduled_running_reqs.append(request) + self.scheduled_req_ids.add(request.request_id) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget + + # Record scheduled LoRA requests. + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len( + self.running + ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs) <= len(self.running) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, scheduled_resumed_reqs, + num_scheduled_tokens, scheduled_spec_decode_tokens, + req_to_new_blocks) + scheduled_cached_reqs = cached_reqs_data + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=scheduled_cached_reqs, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, # type: ignore + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), + structured_output_request_ids={}, + grammar_bitmask=None, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. 
Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + self.requests[req_id].num_computed_tokens += num_scheduled_token + + self.finished_req_ids = set() # type: ignore + return scheduler_output + + def _check_watermark_for_prefill(self, + request, + num_new_tokens, + computed_blocks, + watermark=0.01): + computed_blocks = computed_blocks or [] + watermark_blocks = self.kv_cache_config.num_blocks * watermark + num_computed_tokens = (request.num_computed_tokens + + len(computed_blocks) * self.block_size) + num_required_blocks = cdiv(num_new_tokens + num_computed_tokens, + self.block_size) + req_blocks = self.kv_cache_manager.coordinator.get_blocks( + request.request_id) + num_new_blocks = (num_required_blocks - len(req_blocks[0]) - + len(computed_blocks)) + num_evictable_computed_blocks = sum(1 for blk in computed_blocks + if blk.ref_cnt == 0) + # If number of free blocks is less than water mark after allocating, don't allocate. + if (self.kv_cache_manager.block_pool.get_num_free_blocks() - + num_evictable_computed_blocks - + num_new_blocks) < watermark_blocks: + return False + return True + + def _get_prompt_limit(self, request: Request) -> int: + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if request.lora_request and request.lora_request.long_lora_max_len: + assert prompt_limit <= request.lora_request.long_lora_max_len + return request.lora_request.long_lora_max_len + else: + return prompt_limit + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. + """ + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + if request.status == RequestStatus.RUNNING: + self.scheduled_req_ids.discard(request.request_id) + super().finish_requests(request_ids, finished_status) + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> EngineCoreOutputs: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below + # loop can be a performance bottleneck. We should do our best to avoid + # expensive operations inside the loop. 
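+        # Only the lightweight scheduled_req_ids bookkeeping happens in this
+        # loop; the heavy output processing is delegated to the parent
+        # scheduler's update_from_output() below.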
+ for request in self.running: + req_id = request.request_id + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + continue + if req_id in self.scheduled_req_ids: + self.scheduled_req_ids.remove(req_id) + + return super().update_from_output(scheduler_output, + model_runner_output) diff --git a/vllm_npu/cpu_binding.py b/vllm_npu/cpu_binding.py new file mode 100644 index 0000000..280e516 --- /dev/null +++ b/vllm_npu/cpu_binding.py @@ -0,0 +1,330 @@ +import os +import subprocess +from dataclasses import dataclass +from itertools import accumulate +from typing import Dict, List, Optional, Tuple, Union + +import psutil +import torch_npu +from vllm.logger import logger + +ASCEND_RT_VISIBLE_DEVICES = os.getenv("ASCEND_RT_VISIBLE_DEVICES") +CPU_BINDING_NUM = os.getenv("CPU_BINDING_NUM") + + +def execute_command(cmd_list): + with subprocess.Popen(cmd_list, + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) as p: + out, err = p.communicate(timeout=1000) + res = out.decode() + return res + + +@dataclass +class DeviceInfo: + """ + Parse a single line of device information into structured data. + """ + _info_line: str = "" + npu_id: int = 0 + chip_id: int = 0 + chip_logic_id: Union[int, str] = 0 + chip_name: str = "" + + def __post_init__(self): + npu_id_str, chip_id_str, chip_logic_id_str, self.chip_name = self._info_line.strip( + ).split(None, 3) + self.npu_id = int(npu_id_str) + self.chip_id = int(chip_id_str) + if chip_logic_id_str.isnumeric(): + self.chip_logic_id = int(chip_logic_id_str) + + +class NpuHbmInfo: + visible_npu_ids: Optional[List[int]] = None + hbm_capacity: Optional[int] = None + hbm_usage: Optional[int] = None + + @classmethod + def set_visible_devices(cls, world_size): + """ + Determine which NPUs are visible to the current process and cache their + logical NPU IDs in `cls.visible_npu_ids`. + """ + if cls.visible_npu_ids: + return + if ASCEND_RT_VISIBLE_DEVICES is None: + devices = sorted(list(_get_device_map_info().keys())) + else: + devices_str = ASCEND_RT_VISIBLE_DEVICES + devices = [int(x) for x in devices_str.split(",")] + device_map_info = _get_device_map_info() + npu_ids = [] + for device in devices: + device_info = device_map_info.get(device) + if device_info is None: + raise RuntimeError( + f"Device {device} not found in device_map_info") + npu_ids.append(device_info.npu_id) + cls.visible_npu_ids = npu_ids + + @classmethod + def get_hbm_capacity(cls, rank, world_size, need_nz): + """ + Query and cache the HBM (or DDR) capacity in **bytes** for the NPU assigned + to the current process. 
+ """ + soc_version = torch_npu._C._npu_get_soc_version() + if cls.hbm_capacity: + return cls.hbm_capacity + if not cls.visible_npu_ids: + cls.set_visible_devices(world_size) + assert cls.visible_npu_ids is not None + npu_id = cls.visible_npu_ids[rank] + memory_info = execute_command( + ["npu-smi", "info", "-i", f"{npu_id}", "-t", + "memory"]).split("\n")[1:] + if soc_version == 240: + hbm_capacity_key = 'Capacity(MB)' + elif not need_nz: + hbm_capacity_key = 'HBM Capacity(MB)' + else: + hbm_capacity_key = 'DDR Capacity(MB)' + for line in memory_info: + try: + key, value = line.strip().split(':', 2) + if key.strip() == hbm_capacity_key: + cls.hbm_capacity = int(value.strip()) * 1024 * 1024 + return cls.hbm_capacity + except ValueError: + pass + raise ValueError('not found valid hbm capactiy') + + @classmethod + def get_hbm_usage(cls, rank, world_size, need_nz): + """ + Return the current HBM or DDR usage + ratio (0-1) for the NPU assigned to the given rank. + """ + if cls.hbm_usage: + return cls.hbm_usage + if not cls.visible_npu_ids: + cls.set_visible_devices(world_size) + assert cls.visible_npu_ids is not None + npu_id = cls.visible_npu_ids[rank] + usage_info = execute_command( + ["npu-smi", "info", "-i", f"{npu_id}", "-t", + "usages"]).split("\n")[1:] + soc_version = torch_npu._C._npu_get_soc_version() + if soc_version == 240: + hbm_capacity_key = 'Memory Usage Rate(%)' + elif not need_nz: + hbm_capacity_key = 'HBM Usage Rate(%)' + else: + hbm_capacity_key = 'DDR Usage Rate(%)' + for line in usage_info: + try: + key, value = line.strip().split(':', 2) + if key.strip() == hbm_capacity_key: + hbm_usage = (float(value.strip()) + 1) / 100 + return hbm_usage + except ValueError: + pass + raise ValueError('not found valid hbm usage') + + +def _get_device_map_info() -> Dict[int, DeviceInfo]: + """ + Build and return a mapping from logical chip ID (int) to its DeviceInfo object. + """ + device_map_info = {} + device_map = execute_command(["npu-smi", "info", + "-m"]).strip().split("\n")[1:] + for line in device_map: + device_info = DeviceInfo(line.strip()) + if isinstance(device_info.chip_logic_id, int): + device_map_info[device_info.chip_logic_id] = device_info + return device_map_info + + +def _get_pcie_info(devices: List[int], keyword="PCIeBusInfo"): + """ + Query each NPU in the given device list and return a mapping + from logical device ID to its PCIe bus address. + """ + device_map_info = _get_device_map_info() + device_pcie_tbl = {} + for device in devices: + device_info = device_map_info.get(device) + if not device_info: + raise RuntimeError( + "Can not get device info, you can use BIND_CPU=0 to skip.") + pcie_info = execute_command([ + "npu-smi", "info", "-t", "board", "-i", f"{device_info.npu_id}", + "-c", f"{device_info.chip_id}" + ]).strip().split("\n") + for _ in pcie_info: + line = ''.join(_.split()) + if line.startswith(keyword): + device_pcie_tbl[device] = line[len(keyword) + 1:] + break + + return device_pcie_tbl + + +def _get_numa_info(pcie_tbl, keyword="NUMAnode"): + """ + Build two mappings: device → NUMA node, and NUMA node → [devices]. 
+ """ + device_numa_tbl: Dict[int, int] = {} # device id -> numa id + numa_devices_tbl: Dict[int, List[int]] = {} # numa id -> device ids + + for device, pcie_no in pcie_tbl.items(): + numa_info = execute_command(["lspci", "-s", f"{pcie_no}", + "-vvv"]).split("\n") + for _ in numa_info: + line = ''.join(_.split()) + if line.startswith(keyword): + numa_id = int(line[len(keyword) + 1:]) + device_numa_tbl[device] = numa_id + + devices = numa_devices_tbl.get(numa_id, None) + if devices is None: + numa_devices_tbl[numa_id] = list() + + numa_devices_tbl[numa_id].append(device) + break + + return device_numa_tbl, numa_devices_tbl + + +def _get_numa_info_v2( + devices: List[int], + keyword="NUMAnode(s)") -> Tuple[Dict[int, int], Dict[int, List[int]]]: + """ + Evenly distribute the given device list across all NUMA nodes and return + both device-to-numa and numa-to-devices mappings. + """ + numa_nodes = 1 + numa_info = execute_command(["lscpu"]).split("\n") + for _ in numa_info: + line = ''.join(_.split()) + if keyword not in line: + continue + numa_nodes = int(line[-1]) + break + + device_per_numa, tail_device = divmod(len(devices), numa_nodes) + device_count_per_numa_list = [ + device_per_numa + (i < tail_device) for i in range(numa_nodes) + ] + + ends = list(accumulate(device_count_per_numa_list)) + starts = [0] + ends[:-1] + + numa_devices_tbl = { + ind: devices[start:end] + for ind, (start, end) in enumerate(zip(starts, ends)) + } + + device_numa_tbl = { + device: numa + for numa, _devices in numa_devices_tbl.items() + for device in _devices + } + + return device_numa_tbl, numa_devices_tbl + + +def _get_cpu_info(numa_ids, keyword1="NUMAnode", keyword2="CPU(s)"): + """ + Parse lscpu output to build a dict that maps each NUMA + node ID to the list of CPU core IDs belonging to it. + """ + cpu_idx_tbl = dict() + numa_keywords = [keyword1 + str(idx) + keyword2 for idx in numa_ids] + cpu_info = execute_command(["lscpu"]).split("\n") + for _ in cpu_info: + line = ''.join(_.split()) + if any(line.startswith(word) for word in numa_keywords): + split_info = line.split(":") + cpu_id_ranges = split_info[-1].split(",") + + ranges = list() + for range_str in cpu_id_ranges: + endpoints = range_str.split("-") + if len(endpoints) != 2: + raise Exception( + "lscpu command output error, please check !") + + ranges += [ + cid for cid in range(int(endpoints[0]), + int(endpoints[1]) + 1) + ] + + numa_id = int(split_info[0].replace(keyword1, + '').replace(keyword2, '')) + cpu_idx_tbl[numa_id] = ranges + return cpu_idx_tbl + + +def bind_cpus(rank_id, ratio=0.5): + # get all visible devices + visible_devices = ASCEND_RT_VISIBLE_DEVICES + + if visible_devices is None: + devices = sorted(list(_get_device_map_info().keys())) + else: + devices = [int(x) for x in visible_devices.split(",")] + + # Query the NUMA affinity of each NPU via its PCIe address; if this fails, + # fall back to evenly distributing the devices across NUMA nodes. + device_pcie_tbl = _get_pcie_info(devices) + device_numa_tbl, numa_devices_tbl = _get_numa_info(device_pcie_tbl) + if not device_numa_tbl or not numa_devices_tbl: + device_numa_tbl, numa_devices_tbl = _get_numa_info_v2(devices) + + # Obtain the complete list of CPU cores for each NUMA node. 
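+    # Illustrative shapes only (actual values depend on the host topology):
+    #   numa_devices_tbl -> {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
+    #   cpu_idx_tbl      -> {0: [0, 1, ..., 47], 1: [48, 49, ..., 95]}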
+ cpu_idx_tbl = _get_cpu_info(list(numa_devices_tbl.keys())) + + cur_device = devices[rank_id] + numa_id = device_numa_tbl.get(cur_device) + + # Within the NUMA node, evenly partition the CPU cores + # among all NPUs (or use the amount specified by CPU_BINDING_NUM) + shard_devices = numa_devices_tbl.get(numa_id) + shard_devices.sort() + + all_cpus = cpu_idx_tbl.get(numa_id) + logger.info( + f"rank_id: {rank_id}, device_id: {cur_device}, " + f"numa_id: {numa_id}, shard_devices: {shard_devices}, cpus: {all_cpus}" + ) + + cpu_nums = len(all_cpus) + if CPU_BINDING_NUM is None: + cpu_num_per_device = int(cpu_nums * ratio // len(shard_devices)) + else: + cpu_num_per_device = int(CPU_BINDING_NUM) + if len(shard_devices) * cpu_num_per_device > cpu_nums: + raise RuntimeError( + f"Cpu num in numa {numa_id} to assign {cpu_num_per_device} for every device is not enough, " + f"please decrease the value of CPU_BINDING_NUM!") + if cpu_num_per_device < 0: + raise ValueError("CPU_BINDING_NUM should not be less than 0.") + + idx = shard_devices.index(cur_device) + binding_cpus = [ + all_cpus[_] for _ in range(idx * cpu_num_per_device, (idx + 1) * + cpu_num_per_device) + ] + + # cpu bind + p = psutil.Process() + p.cpu_affinity(binding_cpus) + new_affinity = p.cpu_affinity() + logger.info( + f"process {p.pid}, new_affinity is {new_affinity}, cpu count {cpu_num_per_device}" + ) diff --git a/vllm_npu/cuda_compat.py b/vllm_npu/cuda_compat.py deleted file mode 100644 index 8aea3a2..0000000 --- a/vllm_npu/cuda_compat.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -CUDA-to-NPU Compatibility Layer. - -Monkey-patches ``torch.cuda`` APIs so that code written for CUDA -(e.g. ``torch.cuda.Stream()``, ``torch.cuda.Event()``) transparently -delegates to the corresponding ``torch.npu`` equivalents. This -allows vLLM's ``GPUModelRunner`` to run on Ascend NPU without -source modifications. 
-""" - -import types -from unittest.mock import MagicMock - -import torch - - -def _patch_cuda_to_npu() -> None: - """Apply monkey-patches: redirect torch.cuda → torch.npu.""" - import torch_npu # noqa: F401 - - # ------------------------------------------------------------------ - # Stream / Event - # ------------------------------------------------------------------ - torch.cuda.Stream = torch.npu.Stream # type: ignore[attr-defined] - torch.cuda.Event = torch.npu.Event # type: ignore[attr-defined] - torch.cuda.current_stream = torch.npu.current_stream # type: ignore - torch.cuda.default_stream = torch.npu.default_stream # type: ignore - - # torch.cuda.stream() context manager - torch.cuda.stream = torch.npu.stream # type: ignore[attr-defined] - - # ------------------------------------------------------------------ - # Device management - # ------------------------------------------------------------------ - torch.cuda.set_device = torch.npu.set_device # type: ignore - torch.cuda.synchronize = torch.npu.synchronize # type: ignore - torch.cuda.device_count = torch.npu.device_count # type: ignore - torch.cuda.current_device = torch.npu.current_device # type: ignore - torch.cuda.is_available = lambda: True # type: ignore - - # ------------------------------------------------------------------ - # Memory management - # ------------------------------------------------------------------ - torch.cuda.empty_cache = torch.npu.empty_cache # type: ignore - torch.cuda.mem_get_info = torch.npu.mem_get_info # type: ignore - torch.cuda.memory_allocated = torch.npu.memory_allocated # type: ignore - torch.cuda.max_memory_allocated = torch.npu.max_memory_allocated # type: ignore - torch.cuda.memory_reserved = torch.npu.memory_reserved # type: ignore - torch.cuda.max_memory_reserved = torch.npu.max_memory_reserved # type: ignore - torch.cuda.reset_peak_memory_stats = torch.npu.reset_peak_memory_stats # type: ignore - torch.cuda.memory_stats = torch.npu.memory_stats # type: ignore - - # ------------------------------------------------------------------ - # Device properties - # ------------------------------------------------------------------ - _real_npu_props = torch.npu.get_device_properties - - def _get_device_properties(device=None): - """Return NPU device properties with CUDA-compatible attributes.""" - props = _real_npu_props(device) - # GPUModelRunner accesses .multi_processor_count which may not - # exist on NPU. Provide a sensible fallback. 
- if not hasattr(props, "multi_processor_count"): - props.multi_processor_count = 1 # type: ignore[attr-defined] - if not hasattr(props, "major"): - props.major = 9 # type: ignore[attr-defined] - props.minor = 0 # type: ignore[attr-defined] - return props - - torch.cuda.get_device_properties = _get_device_properties # type: ignore - - # ------------------------------------------------------------------ - # Misc - # ------------------------------------------------------------------ - if not hasattr(torch.cuda, "_get_device_index"): - torch.cuda._get_device_index = torch.npu._get_device_index # type: ignore - - # graph / CUDAGraph stubs (NPU does not support CUDA graphs) - if not hasattr(torch.cuda, "CUDAGraph") or True: - torch.cuda.CUDAGraph = MagicMock # type: ignore[attr-defined] - - if not hasattr(torch.cuda, "graph"): - - def _noop_graph(*args, **kwargs): - """No-op context manager for CUDA graphs on NPU.""" - import contextlib - return contextlib.nullcontext() - - torch.cuda.graph = _noop_graph # type: ignore[attr-defined] diff --git a/vllm_npu/device_allocator/__init__.py b/vllm_npu/device_allocator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/device_allocator/camem.py b/vllm_npu/device_allocator/camem.py new file mode 100644 index 0000000..4efa617 --- /dev/null +++ b/vllm_npu/device_allocator/camem.py @@ -0,0 +1,278 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# CANN-mem-based pytorch pluggable allocator to implement sleep mode. +# +import dataclasses +import os +from contextlib import contextmanager +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from acl.rt import memcpy # type: ignore # noqa: F401 +from vllm.logger import logger + +from vllm_npu.platform import NPUPlatform + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
+ """ # noqa + found_line = None + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found_line = line + break + if found_line is None: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = found_line.index("/") + path = found_line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +camem_available = False +try: + from vllm_npu.vllm_npu_C import ( # type: ignore # noqa: F401 + init_module, python_create_and_map, python_unmap_and_release) + lib_name = find_loaded_library("vllm_npu_C") + camem_available = True +except ImportError as e: + logger.warning( + "Failed to import vllm_npu_C:%s. Sleep mode will be disabled. ", e) + init_module = None + python_create_and_map = None + python_unmap_and_release = None + lib_name = None + libcudart = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = Tuple[int, int, int, int] + + +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: Optional[torch.Tensor] = None + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[tuple[int, int, int, int]], None], + python_free_func: Callable[[int], tuple[int, int, int, int]] +) -> torch.npu.memory.NPUPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, 'my_malloc', + 'my_free') + return new_alloc + + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[tuple[int, int, int, int]], None], + python_free_func: Callable[[int], tuple[int, int, int, int]]): + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch.npu.memory.MemPool(new_alloc._allocator) + with torch.npu.memory.use_mem_pool(mem_pool): + yield mem_pool, new_alloc + + +class CaMemAllocator: + """ + A singleton class that manages a memory pool for CANN tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + instance = None + default_tag: str = "default" + + @staticmethod + def get_instance() -> "CaMemAllocator": + """ + CaMemAllocator is a singleton class. + We cannot call the constructor directly. 
+ Call this method to get the instance. + """ + if CaMemAllocator.instance is None: + CaMemAllocator.instance = CaMemAllocator() + return CaMemAllocator.instance + + def __init__(self): + conf = os.environ.get("PYTORCH_NPU_ALLOC_CONF", "") + assert "expandable_segments:True" not in conf, \ + ("Expandable segments are not compatible with memory pool. " + "Please track https://github.com/pytorch/pytorch/issues/147851 " + "for the latest updates.") + + self.pointer_to_data: Dict[int, AllocationData] = {} + self.current_tag: str = CaMemAllocator.default_tag + self.allocator_and_pools: Dict[str, Any] = {} + + def python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" + py_d_mem = allocation_handle[2] + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag) + return + + def python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + return data.handle + + def sleep( + self, + offload_tags: Optional[Union[Tuple[str, ...], + str]] = None) -> None: + """ + Put the allocator in sleep mode. + All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded. + :param offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ + if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps + offload_tags = (CaMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags, ) + + assert isinstance(offload_tags, tuple) + + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + if data.tag in offload_tags: + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=NPUPlatform.is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + ACL_MEMCPY_DEVICE_TO_HOST = 2 + dest_max = cpu_ptr + size_in_bytes * 2 + memcpy(cpu_ptr, dest_max, ptr, size_in_bytes, + ACL_MEMCPY_DEVICE_TO_HOST) + data.cpu_backup_tensor = cpu_backup_tensor + unmap_and_release(handle) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory.""" + for ptr, data in self.pointer_to_data.items(): + if tags is None or data.tag in tags: + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + ACL_MEMCPY_HOST_TO_DEVICE = 1 + dest_max = ptr + size_in_bytes * 2 + memcpy(ptr, dest_max, cpu_ptr, size_in_bytes, + ACL_MEMCPY_HOST_TO_DEVICE) + data.cpu_backup_tensor = None + + @contextmanager + def use_memory_pool(self, tag: Optional[str] = None): + """ + A context manager to use the memory pool. + All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag. + :param tag: The tag of the memory allocation. 
If None, the default tag + will be used. + """ + if tag is None: + tag = CaMemAllocator.default_tag + + assert isinstance(tag, str) + + old_tag = self.current_tag + self.current_tag = tag + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator and + # the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released. + # right now it is fine, because we only use this allocator + # during weight loading and kv cache creation, where we only + # allocate memory. + # TODO: we need to find a way to release the memory, + # i.e. calling torch.cuda.empty_cache() + self.current_tag = old_tag + + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. + """ + sum_bytes: int = 0 + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + sum_bytes += handle[1] + return sum_bytes diff --git a/vllm_npu/distributed/__init__.py b/vllm_npu/distributed/__init__.py index 61712bf..f36187f 100644 --- a/vllm_npu/distributed/__init__.py +++ b/vllm_npu/distributed/__init__.py @@ -1 +1,40 @@ -"""Ascend NPU distributed communication (HCCL).""" +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vllm.distributed.kv_transfer.kv_connector.factory import \ + KVConnectorFactory + + +def register_connector(): + KVConnectorFactory.register_connector( + "LLMDataDistCMgrConnector", + "vllm_npu.distributed.llmdatadist_c_mgr_connector", + "LLMDataDistCMgrConnector") + + KVConnectorFactory.register_connector( + "MooncakeConnectorV1", "vllm_npu.distributed.mooncake_connector", + "MooncakeConnector") + + KVConnectorFactory.register_connector( + "MooncakeConnectorStoreV1", + "vllm_npu.distributed.mooncake.mooncake_store_connector_v1", + "MooncakeConnectorV1") + + KVConnectorFactory.register_connector( + "MooncakeLayerwiseConnector", + "vllm_npu.distributed.mooncake_layerwise_connector", + "MooncakeLayerwiseConnector") diff --git a/vllm_npu/distributed/communicator.py b/vllm_npu/distributed/communicator.py index 627f327..7c14bef 100644 --- a/vllm_npu/distributed/communicator.py +++ b/vllm_npu/distributed/communicator.py @@ -1,42 +1,46 @@ -""" -NPUCommunicator — HCCL-based device communicator for Ascend NPU. - -Extends ``DeviceCommunicatorBase`` with NPU-specific collective -operations using the HCCL backend. -""" - +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# from typing import List, Optional import torch import torch.distributed as dist -from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase, -) +from vllm.distributed.device_communicators.base_device_communicator import \ + DeviceCommunicatorBase class NPUCommunicator(DeviceCommunicatorBase): - """Device communicator for Ascend NPU using HCCL.""" - def __init__( - self, - cpu_group: dist.ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[dist.ProcessGroup] = None, - unique_name: str = "", - ): + def __init__(self, + cpu_group: dist.ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[dist.ProcessGroup] = None, + unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) - import torch_npu # noqa: F401 + # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator + # init device according to rank self.device = torch.npu.current_device() - def all_to_all( - self, - input_: torch.Tensor, - scatter_dim: int = 0, - gather_dim: int = -1, - scatter_sizes: Optional[List[int]] = None, - gather_sizes: Optional[List[int]] = None, - ) -> torch.Tensor: - """All-to-all communication for NPU tensors.""" + def all_to_all(self, + input_: torch.Tensor, + scatter_dim: int = 0, + gather_dim: int = -1, + scatter_sizes: Optional[List[int]] = None, + gather_sizes: Optional[List[int]] = None) -> torch.Tensor: + if scatter_dim < 0: scatter_dim += input_.dim() if gather_dim < 0: @@ -53,22 +57,17 @@ class NPUCommunicator(DeviceCommunicatorBase): tensor_shape = list(tensor_shape_base) tensor_shape[gather_dim] = gather_sizes[i] output_list.append( - torch.empty( - tensor_shape, - dtype=input_.dtype, - device=input_.device, - ) - ) + torch.empty(tensor_shape, + dtype=input_.dtype, + device=input_.device)) + else: input_list = [ - t.contiguous() - for t in torch.tensor_split( - input_, self.world_size, scatter_dim - ) + t.contiguous() for t in torch.tensor_split( + input_, self.world_size, scatter_dim) ] output_list = [ - torch.empty_like(input_list[i]) - for i in range(self.world_size) + torch.empty_like(input_list[i]) for i in range(self.world_size) ] dist.all_to_all(output_list, input_list, group=self.device_group) diff --git a/vllm_npu/distributed/cpu_offload_connector.py b/vllm_npu/distributed/cpu_offload_connector.py new file mode 100644 index 0000000..f77293c --- /dev/null +++ b/vllm_npu/distributed/cpu_offload_connector.py @@ -0,0 +1,471 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +import queue +import threading +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Sequence + +import torch +from vllm.attention import AttentionType +from vllm.attention.layer import Attention +from vllm.config import 
VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.utils import logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + MLAAttentionSpec) + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.distributed.cpu_offload_manager.metadata import ( + MetadataServer, MetadataServerProc, MLAConfig) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + + +@dataclass +class ReqMeta: + gpu_block_ids: list[int] + cpu_block_ids: list[int] + num_scheduled_tokens: int + num_computed_tokens: int + num_gpu_computed_tokens: int + num_cpu_computed_tokens: int + + def update(self, other: "ReqMeta"): + self.gpu_block_ids.extend(other.gpu_block_ids) + self.cpu_block_ids.extend(other.cpu_block_ids) + self.num_scheduled_tokens = other.num_scheduled_tokens + self.num_computed_tokens = other.num_computed_tokens + self.num_gpu_computed_tokens = other.num_gpu_computed_tokens + self.num_cpu_computed_tokens = other.num_cpu_computed_tokens + + +@dataclass +class CPUOffloadingConnectorMetadata(KVConnectorMetadata): + requests: dict[str, ReqMeta] + finished_req_ids: set[str] + + +class CPUOffloadingConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + if not vllm_config.cache_config.enable_prefix_caching: + self.connector_scheduler: Optional[ + CPUOffloadingConnectorScheduler] = None + self.connector_worker: Optional[ + CPUOffloadingConnectorWorker] = None + elif role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = CPUOffloadingConnectorScheduler( + vllm_config) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = CPUOffloadingConnectorWorker(vllm_config) + + # ============================== + # Worker-side methods + # ============================== + + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + if self.connector_worker is not None: + assert isinstance(connector_metadata, + CPUOffloadingConnectorMetadata) + self.connector_worker.bind_connector_metadata(connector_metadata) + + def clear_connector_metadata(self) -> None: + assert self.connector_worker is not None + self.connector_worker.clear_connector_metadata() + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + if self.connector_worker is not None: + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + if self.connector_worker is not None: + self.connector_worker.start_load_kv() + + def wait_for_layer_load(self, layer_name: str) -> None: + if self.connector_worker is not None: + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + pass + + def wait_for_save(self): + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished(), None 
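+
+    # Illustrative configuration sketch (an assumption, not taken from this
+    # diff): the connector is expected to be selected through vLLM's
+    # kv_transfer_config, and CPUOffloadingConnectorScheduler below reads
+    # "swap_in_threshold" from kv_connector_extra_config, e.g.
+    #   --kv-transfer-config '{"kv_connector": "CPUOffloadingConnector",
+    #                          "kv_role": "kv_both",
+    #                          "kv_connector_extra_config":
+    #                              {"swap_in_threshold": 4}}'
+    # Note that __init__ above leaves both halves as None unless prefix
+    # caching is enabled.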
+ + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + if self.connector_scheduler is not None: + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + if self.connector_scheduler is not None: + return self.connector_scheduler.update_state_after_alloc(request) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + if self.connector_scheduler is not None: + return self.connector_scheduler.build_connector_meta( + scheduler_output) + return KVConnectorMetadata() + + def request_finished( + self, request: "Request", + block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]: + if self.connector_scheduler is not None: + self.connector_scheduler.request_finished(request) + return True, None + + +class CPUOffloadingConnectorScheduler: + + def __init__(self, vllm_config: VllmConfig): + logger.info("init CPUOffloadingConnectorScheduler") + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.use_mla = vllm_config.model_config.use_mla + self.num_gpu_computed_tokens: dict[str, int] = {} + self.num_cpu_computed_tokens: dict[str, int] = {} + self.allocated_req_ids: set[str] = set() + self.finished_req_ids: list[str] = [] + self.zmq_rpc_client = MetadataServer.ZMQRPCClient() + self.zmq_rpc_client.call("post_init") + if vllm_config.kv_transfer_config is not None: + self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config( + "swap_in_threshold", 0) + else: + self.swap_in_threshold = 0 + logger.info(f"swap_in_threshold: {self.swap_in_threshold}") + + def get_num_new_matched_tokens( + self, ori_request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + request = copy.deepcopy(ori_request) + request.get_hash_new_full_blocks = None + num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call( + "get_matched_num_and_touch", request) + self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens + self.num_cpu_computed_tokens[ + request.request_id] = num_cpu_computed_tokens + if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold: + return num_cpu_computed_tokens - num_computed_tokens, load_async + else: + return 0, load_async + + def update_state_after_alloc(self, request: "Request"): + self.allocated_req_ids.add(request.request_id) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + num_tokens = {} + # process scheduled_new_reqs + for req in scheduler_output.scheduled_new_reqs: + req_id = req.req_id + num_tokens[req_id] = ( + req.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # process scheduled_cached_reqs + cached_reqs = scheduler_output.scheduled_cached_reqs + for idx, req_id in enumerate(cached_reqs.req_ids): + num_tokens[req_id] = ( + cached_reqs.num_computed_tokens[idx] + + scheduler_output.num_scheduled_tokens[req_id]) + + unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() - + self.allocated_req_ids - + scheduler_output.num_scheduled_tokens.keys()) + new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", + num_tokens, + unallocated_req_ids) + metadata = CPUOffloadingConnectorMetadata( + requests={}, + finished_req_ids=set(self.finished_req_ids), + ) + for req in 
scheduler_output.scheduled_new_reqs: + req_id = req.req_id + gpu_block_ids = req.block_ids[0] + metadata.requests[req_id] = ReqMeta( + gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, + cpu_block_ids=new_cpu_block_ids.get(req_id, []), + num_scheduled_tokens=scheduler_output. + num_scheduled_tokens[req_id], + num_computed_tokens=req.num_computed_tokens, + num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id], + num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id]) + + for idx, req_id in enumerate(cached_reqs.req_ids): + gpu_block_ids = cached_reqs.new_block_ids[idx] + metadata.requests[req_id] = ReqMeta( + gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, + cpu_block_ids=new_cpu_block_ids.get(req_id, []), + num_scheduled_tokens=scheduler_output. + num_scheduled_tokens[req_id], + num_computed_tokens=cached_reqs.num_computed_tokens[idx], + num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx], + num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx]) + self.num_gpu_computed_tokens.clear() + self.num_cpu_computed_tokens.clear() + self.allocated_req_ids.clear() + self.finished_req_ids.clear() + return metadata + + def request_finished(self, ori_request: "Request"): + request = copy.deepcopy(ori_request) + request.get_hash_new_full_blocks = None + self.finished_req_ids.append(request.request_id) + # inform metadata server to record request, and free it after finish sending + self.zmq_rpc_client.call("record_request_cache_and_free_slots", + request) + + +class CPUOffloadingConnectorWorker: + + def __init__(self, vllm_config: VllmConfig): + logger.info("init CPUOffloadingConnectorWorker") + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.pp_rank = get_pp_group().rank_in_group + self.tp_group = get_tp_group() + self.tp_rank = self.tp_group.rank_in_group + self.tp_world_size = self.tp_group.world_size + self.use_mla = vllm_config.model_config.use_mla + + self.requests: dict[str, ReqMeta] = {} + self.load_stream = torch.npu.Stream() + self.save_stream = torch.npu.Stream() + self.zmq_rpc_client = MetadataServer.ZMQRPCClient() + self.load_block_mapping: list[tuple[int, int]] = [] + self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue() + self.save_output_queue: queue.Queue[str] = queue.Queue() + self.save_thread = threading.Thread(target=self._save_listener) + self.save_thread.start() + self.done_sending_count: defaultdict[str, int] = defaultdict(int) + + # start metadata server to init cpu_kv_cache_manager and handle rpc requests + # all dp shared the same metadata server, only start the process on data_rank 0 + if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0: + config = VllmConfig() + config.cache_config = vllm_config.cache_config + config.parallel_config = vllm_config.parallel_config + config.kv_transfer_config = vllm_config.kv_transfer_config + self.init_metadata_server(config) + self._wait_for_metadata_process_start() + + def init_metadata_server(self, vllm_config: VllmConfig): + self.metadata_thread = threading.Thread( + target=MetadataServerProc.run_metadata_server, + args=(vllm_config, ), + ) + self.metadata_thread.daemon = True + self.metadata_thread.start() + + def _wait_for_metadata_process_start(self): + # TODO: wait for metadata server to start, add a rpc to check if ready + while True: + try: + if self.zmq_rpc_client.call("ready"): + break + except Exception as e: + logger.info(f"wait for metadata server to start, error: 
{e}") + time.sleep(1) + + def bind_connector_metadata( + self, connector_metadata: CPUOffloadingConnectorMetadata) -> None: + for req_id, req in connector_metadata.requests.items(): + if req_id in self.requests: + self.requests[req_id].update(req) + req = self.requests[req_id] + else: + self.requests[req_id] = req + for i in range(req.num_gpu_computed_tokens // self.block_size, + req.num_computed_tokens // self.block_size): + self.load_block_mapping.append( + (req.cpu_block_ids[i], req.gpu_block_ids[i])) + for req_id in connector_metadata.finished_req_ids: + if req_id in self.requests: + self.save_input_queue.put((req_id, self.requests[req_id])) + + def clear_connector_metadata(self) -> None: + self.load_block_mapping.clear() + + def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]): + self.gpu_kv_caches = kv_caches + model_config = self.vllm_config.model_config + mla_config: Optional[MLAConfig] = None + if model_config.use_mla: + mla_config = MLAConfig( + model_config.hf_text_config.kv_lora_rank, + model_config.hf_text_config.qk_rope_head_dim) + self.cpu_kv_caches = list( + self.zmq_rpc_client.call( + "init_cpu_kv_caches", + self.pp_rank, + self.tp_rank, + get_kv_cache_spec(self.vllm_config), + mla_config, + ).values()) + + def start_load_kv(self) -> None: + self.current_layer = 0 + self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values()) + self.load_kv_layer(0) + + def wait_for_layer_load(self) -> None: + # TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug. + self.load_stream.synchronize() + self.current_layer += 1 + self.load_kv_layer(self.current_layer) + + def load_kv_layer(self, layer: int): + if layer == len(self.gpu_kv_caches): + return + gpu_kv_caches = next(self.gpu_kv_caches_load_iter) + cpu_kv_caches = self.cpu_kv_caches[layer] + with torch.npu.stream(self.load_stream): + for cpu_block_id, gpu_block_id in self.load_block_mapping: + for gpu_layer_part, cpu_layer_part in zip( + gpu_kv_caches, cpu_kv_caches): + gpu_layer_part[gpu_block_id].copy_( + cpu_layer_part[cpu_block_id], non_blocking=True) + + def get_finished(self) -> set[str]: + done_sending: set[str] = set() + while True: + try: + id = self.save_output_queue.get_nowait() + except queue.Empty: + break + done_sending.add(id) + for id in done_sending: + del self.requests[id] + if self.tp_world_size == 1: + return done_sending + if self.tp_rank == 0: + for req_id in done_sending: + self.done_sending_count[req_id] += 1 + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.tp_world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + self.done_sending_count[req_id] += 1 + all_done_sending: set[str] = set() + for req_id in list(self.done_sending_count.keys()): + if self.done_sending_count[req_id] == self.tp_world_size: + del self.done_sending_count[req_id] + all_done_sending.add(req_id) + # release cpu_kv_cache after request sending finished + # to avoid rpc blocking, use thread to call rpc asynchronously + sending_finished_thread = threading.Thread( + target=self._sending_finished, args=(all_done_sending, )) + sending_finished_thread.daemon = True + sending_finished_thread.start() + + return all_done_sending + else: + self.tp_group.send_object(done_sending, dst=0) + return done_sending + + def _sending_finished(self, all_done_sending): + for req_id in all_done_sending: + logger.debug(f"call cache_and_free_slots for req_id: {req_id}") + 
self.zmq_rpc_client.call("cache_and_free_slots", req_id) + + def _save_listener(self): + save_block_mapping = [] + while True: + req_id, req = self.save_input_queue.get() + for i in range( + req.num_cpu_computed_tokens // self.block_size, + min((req.num_computed_tokens + req.num_scheduled_tokens) // + self.block_size, len(req.cpu_block_ids))): + save_block_mapping.append( + (req.gpu_block_ids[i], req.cpu_block_ids[i])) + with torch.npu.stream(self.save_stream): + # MLA: kv_layer is tuple[tensor, tensor] means (rope, nope). + # non-MLA: kv_layer is list[tensor], typically means [k, v]. + if self.use_mla: + start, step = self.tp_rank, self.tp_world_size + else: + start, step = 0, 1 + for i in range(start, len(save_block_mapping), step): + gpu_block_id, cpu_block_id = save_block_mapping[i] + for cpu_kv_caches, gpu_kv_caches in zip( + self.cpu_kv_caches, self.gpu_kv_caches.values()): + for cpu_layer_part, gpu_layer_part in zip( + cpu_kv_caches, gpu_kv_caches): + cpu_layer_part[cpu_block_id].copy_( + gpu_layer_part[gpu_block_id], + non_blocking=True) + self.save_stream.synchronize() + self.save_output_queue.put(req_id) + save_block_mapping.clear() + + +# Copied from vllm_npu/worker/model_runner_v1.py. +def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: + forward_ctx = vllm_config.compilation_config.static_forward_context + block_size = vllm_config.cache_config.block_size + use_mla = vllm_config.model_config.use_mla + ascend_config = get_ascend_config() + use_sfa = ascend_config.use_sfa + kv_cache_spec: dict[str, KVCacheSpec] = {} + for layer_name, attn_module in forward_ctx.items(): + if isinstance(attn_module, FusedMoE): + continue + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + if use_mla and not use_sfa: + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + cache_dtype_str=vllm_config.cache_config.cache_dtype) + else: + # TODO(cmq): This is a hack way to fix deepseek kvcache when + # using DSA. Fix the spec in vLLM is a finnal way. 
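+                # (i.e. with SFA/DSA enabled, MLA layers are registered with a
+                # plain FullAttentionSpec instead of MLAAttentionSpec.)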
+ kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype) + + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + return kv_cache_spec diff --git a/vllm_npu/distributed/cpu_offload_manager/__init__.py b/vllm_npu/distributed/cpu_offload_manager/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/distributed/cpu_offload_manager/cpu_kv_cache_manager.py b/vllm_npu/distributed/cpu_offload_manager/cpu_kv_cache_manager.py new file mode 100644 index 0000000..fd68189 --- /dev/null +++ b/vllm_npu/distributed/cpu_offload_manager/cpu_kv_cache_manager.py @@ -0,0 +1,202 @@ +import time +from collections import defaultdict +from typing import Optional + +from vllm.utils import logger, sha256 +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + PrefixCachingMetrics) +from vllm.v1.core.single_type_kv_cache_manager import \ + get_manager_for_kv_cache_spec +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.request import Request + + +class CPUCacheStats: + + def __init__(self, enable_prefix_caching: bool, log_stats: bool = False): + self.enable_prefix_caching = enable_prefix_caching + self.log_stats = log_stats + self.prefix_cache_stats = PrefixCacheStats() if log_stats else None + self.cpu_prefix_cache_metrics = PrefixCachingMetrics() + self.time_sec = int(time.time()) + + def log(self): + current_time_sec = int(time.time()) + # Log the prefix cache hit rate every 10 seconds. + if current_time_sec - self.time_sec >= 10: + self.time_sec = current_time_sec + logger.info("CPU Prefix cache hit rate: %.1f%%", + self.cpu_prefix_cache_metrics.hit_rate * 100) + + def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: + """Get (and reset) the prefix cache stats. + Returns: + The current prefix caching stats, or None if logging is disabled. 
+ """ + if not self.log_stats: + return None + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats + + def update(self, num_tokens, num_computed_tokens): + # Note the function is called by scheduler + if self.log_stats and self.enable_prefix_caching: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.requests += 1 + self.prefix_cache_stats.queries += num_tokens + self.prefix_cache_stats.hits += num_computed_tokens + + def set_cache_stats(self, num_tokens, num_computed_tokens): + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.hits = num_computed_tokens + self.prefix_cache_stats.queries = num_tokens + self.prefix_cache_stats.requests = 1 + + +class CPUKVCacheManager: + + def __init__( + self, + kv_cache_spec: KVCacheSpec, + num_cpu_blocks: int, + caching_hash_algo: str = "builtin", + use_eagle: bool = False, + enable_kv_cache_events: bool = False, + ) -> None: + self.block_size = kv_cache_spec.block_size + self.num_cpu_blocks = num_cpu_blocks + self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash + self.use_eagle = use_eagle + self.block_pool = BlockPool(self.num_cpu_blocks, True, + enable_kv_cache_events) + self.single_type_manager = get_manager_for_kv_cache_spec( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + kv_cache_group_id=0, + ) + # Record kv block hashes, avoid redundant computation. + self.req_to_block_hashes: defaultdict[ + str, list[BlockHash]] = defaultdict(list) + # Record blocks touched in get_matched_num_and_touch(). + self.req_to_computed_blocks: defaultdict[ + str, list[KVCacheBlock]] = defaultdict(list) + # Record the request that failed to allocate. + self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool) + self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int) + self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, + log_stats=True) + # Record request that will be free after finish sending + self.req_to_free: defaultdict[str, Request] = defaultdict(Request) + + def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]: + # When the request requires prompt logprobs, we skip prefix caching. + if (request.sampling_params.prompt_logprobs is not None): + return 0, False + request_id = request.request_id + # The block hashes for the request may already be computed + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request_id] + if not block_hashes: + block_hashes = request.block_hashes + self.req_to_block_hashes[request_id] = block_hashes + max_cache_hit_length = request.num_tokens - 1 + computed_blocks = self.single_type_manager.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=[0], + block_pool=self.block_pool, + kv_cache_spec=self.single_type_manager.kv_cache_spec, + use_eagle=self.use_eagle, + ) + num_computed_tokens = len(computed_blocks[0]) * self.block_size + self.req_to_computed_blocks[request_id] = computed_blocks[0] + # We should touch these blocks in the concurrent scenarios. 
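+        # touch() takes an extra ref_cnt on every matched block so it cannot
+        # be evicted between this lookup and the later allocate_slots() call;
+        # _release_ahead_touch() drops that reference again if the allocation
+        # is skipped or fails.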
+        self.block_pool.touch(computed_blocks)
+
+        # CPU prefix cache stats: record and log.
+        assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
+        self.cpu_cache_stats.set_cache_stats(request.num_tokens,
+                                             num_computed_tokens)
+        self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
+            self.cpu_cache_stats.prefix_cache_stats)
+        self.cpu_cache_stats.log()
+
+        return num_computed_tokens, False
+
+    def _release_ahead_touch(self, request_id: str):
+        computed_blocks = self.req_to_computed_blocks[request_id]
+        if computed_blocks:
+            self.single_type_manager.block_pool.free_blocks(
+                reversed(computed_blocks))
+        self.req_to_computed_blocks.pop(request_id, None)
+
+    def allocate_slots(self, req_to_num_tokens: dict[str, int],
+                       unallocated_req_ids: set[str]) -> dict[str, list[int]]:
+        for request_id in unallocated_req_ids:
+            self._free_slots(request_id)
+        req_to_new_blocks = {}
+        for request_id, num_tokens in req_to_num_tokens.items():
+            if self.req_failed_to_allocate[request_id]:
+                continue
+            new_computed_blocks = self.req_to_computed_blocks[request_id]
+            num_blocks_to_allocate = (
+                self.single_type_manager.get_num_blocks_to_allocate(
+                    request_id=request_id,
+                    num_tokens=num_tokens,
+                    new_computed_blocks=new_computed_blocks,
+                ))
+            if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
+                self._release_ahead_touch(request_id)
+                self.req_failed_to_allocate[request_id] = True
+                continue
+            # Append the new computed blocks to the request's blocks now to
+            # avoid the case where the new blocks cannot be allocated.
+            self.single_type_manager.save_new_computed_blocks(
+                request_id, new_computed_blocks)
+            # Allocate new blocks but do not cache them yet.
+            new_blocks = self.single_type_manager.allocate_new_blocks(
+                request_id, num_tokens)
+            self.req_to_num_tokens[request_id] = num_tokens
+            # No need to release the touched ref counts here: they are now
+            # owned through save_new_computed_blocks(), so just drop the record.
+            self.req_to_computed_blocks.pop(request_id, None)
+            req_to_new_blocks[request_id] = [
+                block.block_id for block in new_computed_blocks + new_blocks
+            ]
+        return req_to_new_blocks
+
+    def record_request_cache_and_free_slots(self, request: Request):
+        logger.debug(
+            f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
+        )
+        self.req_to_free[request.request_id] = request
+
+    def cache_and_free_slots(self, request_id: str):
+        logger.debug(
+            f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
+        )
+        if request_id not in self.req_to_free:
+            logger.error(
+                f"request {request_id} not in req_to_free, this may be a bug!")
+            return
+        request = self.req_to_free[request_id]
+        if not self.req_failed_to_allocate[request_id]:
+            self.single_type_manager.cache_blocks(
+                request,
+                self.req_to_num_tokens[request_id],
+            )
+        self._free_slots(request_id)
+        logger.debug(
+            f"delete request {request_id} from cpu_kv_cache_manager req_to_free")
+        del self.req_to_free[request_id]
+
+    def _free_slots(self, request_id: str):
+        # This function is designed to be reentrant.
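+        # Every per-request structure is dropped with a default, so calling it
+        # again for a request that was already freed (or never allocated) is a
+        # no-op instead of an error.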
+ self._release_ahead_touch(request_id) + self.single_type_manager.free(request_id) + self.req_to_block_hashes.pop(request_id, None) + self.req_to_computed_blocks.pop(request_id, None) + self.req_failed_to_allocate.pop(request_id, None) + self.req_to_num_tokens.pop(request_id, None) diff --git a/vllm_npu/distributed/cpu_offload_manager/metadata.py b/vllm_npu/distributed/cpu_offload_manager/metadata.py new file mode 100644 index 0000000..468d0f5 --- /dev/null +++ b/vllm_npu/distributed/cpu_offload_manager/metadata.py @@ -0,0 +1,269 @@ +import math +import os +import pickle +from dataclasses import dataclass +from multiprocessing.shared_memory import SharedMemory +from typing import Any, Callable, Optional + +import torch +import vllm.envs as envs +import zmq +from vllm.config import KVTransferConfig, VllmConfig +from vllm.utils import get_dtype_size, logger, make_zmq_socket +from vllm.v1.kv_cache_interface import AttentionSpec + +from vllm_npu.distributed.cpu_offload_manager.cpu_kv_cache_manager import \ + CPUKVCacheManager + + +@dataclass +class MLAConfig: + nope_dim: int + rope_dim: int + + +def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig: + if vllm_config.kv_transfer_config is not None: + kv_transfer_config = vllm_config.kv_transfer_config + if kv_transfer_config.kv_connector == "CPUOffloadingConnector": + return kv_transfer_config + elif kv_transfer_config.kv_connector == "MultiConnector": + ktcs = kv_transfer_config.kv_connector_extra_config.get( + "connectors") + for ktc in ktcs: + kv_transfer_config = KVTransferConfig(**ktc) + if kv_transfer_config.kv_connector == "CPUOffloadingConnector": + return kv_transfer_config + return None + + +class MetadataServer: + METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc" + DEFAULT_CPU_SWAP_SPACE_GB = 800 + + class ZMQRPCClient: + + def __init__(self, identity=f"worker-{os.getpid()}"): + logger.info(f"metadata client for worker {identity} started") + self.ctx = zmq.Context() # type: ignore + self.socket = make_zmq_socket( + self.ctx, + MetadataServer.METADATA_SERVER_ADDRESS, + zmq.DEALER, # type: ignore + bind=False, + identity=identity.encode(), + linger=0) + + def call(self, func_name: str, *args, **kwargs) -> Any: + request = (func_name, args, kwargs) + self.socket.send(b"", zmq.SNDMORE) # type: ignore + self.socket.send(pickle.dumps(request)) + _ = self.socket.recv() + response = pickle.loads(self.socket.recv()) + result, error = response + if error: + logger.exception(f"call metadata sever error: {error}") + raise error + if func_name == "init_cpu_kv_caches": + (memory_dict, layer_size, layer_dtype, mla_config) = result + # shared_memory_dict is recorded in self to close + self.shared_memory_dict = memory_dict + result = {} + for key, shm in memory_dict.items(): + tensor = torch.frombuffer( + shm.buf, dtype=layer_dtype).reshape(layer_size) + if mla_config is not None: + tensor = tensor.split( + [mla_config.nope_dim, mla_config.rope_dim], dim=-1) + result[key] = tensor + return result + + def __del__(self): + # will be finalized by outer process + self.socket.close() + self.ctx.term() + if hasattr(self, 'shared_memory_dict'): + for shm in self.shared_memory_dict.values(): + shm.close() + + def __init__(self, vllm_config: VllmConfig): + self.world_size = vllm_config.parallel_config.world_size + self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size + kv_transfer_config = get_cpu_offload_connector(vllm_config) + assert kv_transfer_config is not None + 
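+        # The CPU swap space comes from the connector's extra config. An
+        # illustrative (hypothetical) --kv-transfer-config value could look
+        # like:
+        #   {"kv_connector": "CPUOffloadingConnector",
+        #    "kv_connector_extra_config": {"cpu_swap_space_gb": 100}}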
available_memory_gb = kv_transfer_config.get_from_extra_config( + "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB) + self.available_memory = available_memory_gb * 1024 * 1024 * 1024 + logger.info(f"cpu swap space: {self.available_memory} bytes") + self.ctx = zmq.Context() # type: ignore + self.socket = make_zmq_socket( + self.ctx, + MetadataServer.METADATA_SERVER_ADDRESS, + zmq.ROUTER, # type: ignore + bind=True, + linger=0) + self.functions: dict[str, Callable] = { + "init_cpu_kv_caches": self.init_cpu_kv_caches, + "post_init": self.post_init, + "ready": self.ready, + } + self.shared_memory = {} # type: ignore + self.num_cpu_blocks = -1 + + @staticmethod + def _safe_create_shared_memory(name: str, size: int) -> SharedMemory: + try: + existing_shm = SharedMemory(name=name, create=False) + existing_shm.close() + existing_shm.unlink() + except FileNotFoundError: + pass + return SharedMemory(name=name, create=True, size=size) + + def ready(self): + return True + + def init_cpu_kv_caches( + self, + pp_rank: int, + tp_rank: int, + kv_cache_specs: dict[str, AttentionSpec], + mla_config: MLAConfig, + ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, + MLAConfig]: + logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}") + # follow the assumption that each layer has the same spec + layer = next(iter(kv_cache_specs.values())) + assert all([ + layer.page_size_bytes == any.page_size_bytes + for any in kv_cache_specs.values() + ]) + # mla shares the same kv cache among different tp + if layer.use_mla: + tp_rank = 0 + if (pp_rank, tp_rank) in self.shared_memory: + return self.shared_memory[(pp_rank, tp_rank)] + available_memory = self.available_memory + shared_memory_dict = {} + if layer.use_mla: + available_memory //= self.pipeline_parallel_size + available_memory //= len(kv_cache_specs) + num_blocks = available_memory // layer.page_size_bytes + layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, + layer.head_size) # type: ignore + else: + available_memory //= self.world_size + available_memory //= len(kv_cache_specs) + num_blocks = available_memory // layer.page_size_bytes + layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, + layer.head_size) # type: ignore + nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype) + for layer_name in kv_cache_specs.keys(): + # only this format can share during ZeroMQ+pickle + shared_memory_dict[ + layer_name] = MetadataServer._safe_create_shared_memory( + f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes) + if layer.use_mla: + assert mla_config is not None + assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim + self.shared_memory[(pp_rank, + tp_rank)] = (shared_memory_dict, layer_size, + layer.dtype, mla_config) + else: + self.shared_memory[(pp_rank, + tp_rank)] = (shared_memory_dict, layer_size, + layer.dtype, None) + if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks: + self.num_cpu_blocks = num_blocks + self.layer = layer + return self.shared_memory[(pp_rank, tp_rank)] + + def post_init(self): + # different processors in data parallel may call multiple times + if hasattr(self, 'cpu_block_manager'): + return + # do shared_memory() at least once + logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}") + assert self.num_cpu_blocks >= 0 + self.cpu_block_manager = CPUKVCacheManager(self.layer, + self.num_cpu_blocks) + self.functions.update({ + "get_matched_num_and_touch": + self.cpu_block_manager.get_matched_num_and_touch, + "allocate_slots": + 
self.cpu_block_manager.allocate_slots, + "record_request_cache_and_free_slots": + self.cpu_block_manager.record_request_cache_and_free_slots, + "cache_and_free_slots": + self.cpu_block_manager.cache_and_free_slots, + }) + + def serve_step(self): + client_id = self.socket.recv() + _ = self.socket.recv() + raw_msg = self.socket.recv() + try: + func_name, args, kwargs = pickle.loads(raw_msg) + except Exception as e: + response = (None, Exception(f"Invalid request: {str(e)}")) + else: + if func_name in self.functions: + try: + result = self.functions[func_name](*args, **kwargs) + response = (result, None) # type: ignore + except Exception as e: + logger.exception(f"metadata execute error: {e}") + response = (None, e) # type: ignore + else: + response = (None, NameError(f"Function {func_name} not found")) + self.socket.send(client_id, zmq.SNDMORE) # type: ignore + self.socket.send(b"", zmq.SNDMORE) # type: ignore + self.socket.send(pickle.dumps(response)) + + def shutdown(self): + self.socket.close() + self.ctx.term() + socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace( + "ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + for cached in self.shared_memory.values(): + for shm in cached[0].values(): + shm.close() + shm.unlink() + + +class MetadataServerProc: + + @staticmethod + def run_metadata_server(vllm_config: VllmConfig): + if (not vllm_config.cache_config.enable_prefix_caching + or get_cpu_offload_connector(vllm_config) is None): + return + + shutdown_requested = False + + def _signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the worker + # signal.signal(signal.SIGTERM, _signal_handler) + # signal.signal(signal.SIGINT, _signal_handler) + metadata_server: Optional[MetadataServer] = None + try: + metadata_server = MetadataServer(vllm_config) + logger.info("Metadata server started.") + while True: + metadata_server.serve_step() + except SystemExit: + logger.info("Metadata server exiting.") + raise + except Exception as e: + logger.exception(f"Metadata server error: {e}.") + raise e + finally: + if metadata_server is not None: + metadata_server.shutdown() diff --git a/vllm_npu/distributed/device_communicators/__init__.py b/vllm_npu/distributed/device_communicators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/distributed/device_communicators/pyhccl.py b/vllm_npu/distributed/device_communicators/pyhccl.py new file mode 100644 index 0000000..997ab8c --- /dev/null +++ b/vllm_npu/distributed/device_communicators/pyhccl.py @@ -0,0 +1,165 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import logger + +from vllm_npu.distributed.device_communicators.pyhccl_wrapper import ( + HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum, + hcclRedOpTypeEnum, hcclUniqueId) +from vllm_npu.utils import current_stream + + +class PyHcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyHcclCommunicator to. If None, + it will be bind to f"npu:{local_rank}". + library_path: the path to the HCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert dist.get_backend(group) != dist.Backend.HCCL, ( + "PyHcclCommunicator should be attached to a non-HCCL group.") + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + + try: + self.hccl = HCCLLibrary(library_path) + except Exception: + # disable because of missing HCCL library + # e.g. in a non-NPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using pyhccl") + + if isinstance(device, int): + device = torch.device(f"npu:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + if self.rank == 0: + # get the unique id from HCCL + with torch.npu.device(device): + self.unique_id = self.hccl.hcclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = hcclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + + # hccl communicator and stream will use this device + # `torch.npu.device` is a context manager that changes the + # current npu device to the specified one + with torch.npu.device(device): + self.comm: hcclComm_t = self.hccl.hcclCommInitRank( + self.world_size, self.unique_id, self.rank) + + stream = current_stream() + # A small all_reduce for warmup. 
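+            # Reducing a single element and synchronizing verifies that the
+            # freshly created communicator actually works, so configuration or
+            # link problems surface at init time rather than during the first
+            # real collective.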
+ data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def all_reduce(self, + in_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None) -> torch.Tensor: + if self.disabled: + return None + # hccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + hcclDataTypeEnum.from_torch(in_tensor.dtype), + hcclRedOpTypeEnum.from_torch(op), self.comm, + aclrtStream_t(stream.npu_stream)) + return out_tensor + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + if src == self.rank: + buffer = buffer_type(tensor.data_ptr()) + else: + buffer = buffer_type(tensor.data_ptr()) + self.hccl.hcclBroadcast(buffer, tensor.numel(), + hcclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, aclrtStream_t(stream.npu_stream)) diff --git a/vllm_npu/distributed/device_communicators/pyhccl_wrapper.py b/vllm_npu/distributed/device_communicators/pyhccl_wrapper.py new file mode 100644 index 0000000..0e7b0cd --- /dev/null +++ b/vllm_npu/distributed/device_communicators/pyhccl_wrapper.py @@ -0,0 +1,253 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp +from vllm.logger import logger + +from vllm_npu.utils import find_hccl_library + +# export types and functions from hccl to Python === +# for the original hccl definition, please check +# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90 +# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48 + +hcclResult_t = ctypes.c_int +hcclComm_t = ctypes.c_void_p + + +class hcclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 4108)] + + +aclrtStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +hcclDataType_t = ctypes.c_int + + +class hcclDataTypeEnum: + hcclInt8 = 0 + hcclInt16 = 1 + hcclInt32 = 2 + hcclFloat16 = 3 + hcclFloat32 = 4 + hcclInt64 = 5 + hcclUint64 = 6 + hcclUint8 = 7 + hcclUint16 = 8 + hcclUint32 = 9 + hcclFloat64 = 10 + hcclBfloat16 = 11 + hcclInt128 = 12 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.hcclInt8 + if dtype == torch.uint8: + return cls.hcclUint8 + if dtype == torch.int32: + return cls.hcclInt32 + if dtype == torch.int64: + return cls.hcclInt64 + if dtype == torch.float16: + return cls.hcclFloat16 + if dtype == torch.float32: + return cls.hcclFloat32 + if dtype == torch.float64: + return cls.hcclFloat64 + if dtype == torch.bfloat16: + return cls.hcclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +hcclRedOp_t = ctypes.c_int + + +class hcclRedOpTypeEnum: + hcclSum = 0 + hcclProd = 1 + hcclMax = 2 + hcclMin = 3 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.hcclSum + if op == ReduceOp.PRODUCT: + return cls.hcclProd + if op == ReduceOp.MAX: + return cls.hcclMax + if op == ReduceOp.MIN: + return cls.hcclMin + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class HCCLLibrary: + exported_functions = [ + # const char* HcclGetErrorString(HcclResult code); + Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]), + + # HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo); + Function("HcclGetRootInfo", hcclResult_t, + [ctypes.POINTER(hcclUniqueId)]), + + # HcclResult HcclCommInitRootInfo( + # uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm); + # note that HcclComm is a pointer type, so the last argument is a pointer to a pointer + Function("HcclCommInitRootInfo", hcclResult_t, [ + ctypes.c_int, + ctypes.POINTER(hcclUniqueId), + ctypes.c_int, + ctypes.POINTER(hcclComm_t), + ]), + + # HcclResult HcclAllReduce( + # void *sendBuf, void *recvBuf, uint64_t count, + # HcclDataType dataType, HcclReduceOp op, HcclComm comm, + # aclrtStream stream); + Function("HcclAllReduce", hcclResult_t, [ + buffer_type, + buffer_type, + ctypes.c_size_t, + hcclDataType_t, + hcclRedOp_t, + hcclComm_t, + aclrtStream_t, + ]), + + # HcclResult HcclBroadcast( + # void *buf, uint64_t count, + # HcclDataType dataType, uint32_t root, + # HcclComm comm, aclrtStream stream); + Function("HcclBroadcast", hcclResult_t, [ + buffer_type, + ctypes.c_size_t, + hcclDataType_t, + ctypes.c_int, + hcclComm_t, + aclrtStream_t, + ]), + + # HcclResult HcclCommDestroy(HcclComm comm); + Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]), + ] + + # class attribute 
to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the correspongding directory + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_hccl_library() + + try: + if so_file not in HCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + HCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = HCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load HCCL library from %s. " + "It is expected if you are not running on Ascend NPUs." + "Otherwise, the hccl library might not exist, be corrupted " + "or it does not support the current platform %s. " + "If you already have the library, please set the " + "environment variable HCCL_SO_PATH" + " to point to the correct hccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in HCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in HCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + HCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = HCCLLibrary.path_to_dict_mapping[so_file] + + def hcclGetErrorString(self, result: hcclResult_t) -> str: + return self._funcs["HcclGetErrorString"](result).decode("utf-8") + + def HCCL_CHECK(self, result: hcclResult_t) -> None: + if result != 0: + error_str = self.hcclGetErrorString(result) + raise RuntimeError(f"HCCL error: {error_str}") + + def hcclGetUniqueId(self) -> hcclUniqueId: + unique_id = hcclUniqueId() + self.HCCL_CHECK(self._funcs["HcclGetRootInfo"]( + ctypes.byref(unique_id))) + return unique_id + + def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, + rank: int) -> hcclComm_t: + comm = hcclComm_t() + self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"]( + world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))) + return comm + + def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: hcclComm_t, + stream: aclrtStream_t) -> None: + # `datatype` actually should be `hcclDataType_t` + # and `op` should be `hcclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int, + root: int, comm: hcclComm_t, + stream: aclrtStream_t) -> None: + self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, + root, comm, stream)) + + def hcclCommDestroy(self, comm: hcclComm_t) -> None: + self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm)) + + +__all__ = [ + "HCCLLibrary", + "hcclDataTypeEnum", + "hcclRedOpTypeEnum", + "hcclUniqueId", + "hcclComm_t", + "aclrtStream_t", + "buffer_type", +] diff --git a/vllm_npu/distributed/llmdatadist_c_mgr_connector.py b/vllm_npu/distributed/llmdatadist_c_mgr_connector.py new file mode 100644 index 0000000..d6c33e6 --- /dev/null +++ b/vllm_npu/distributed/llmdatadist_c_mgr_connector.py @@ -0,0 +1,994 @@ +import contextlib +import copy +import json +import math +import os +import threading +import time +from collections import defaultdict +from collections.abc import Iterator 
+from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Optional, Tuple + +import llm_datadist # type: ignore +import msgspec +import torch +import zmq +from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, + LLMException, LLMRole) +from vllm import envs +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import get_tp_group, get_world_group +from vllm.forward_context import ForwardContext +from vllm.utils import get_ip, logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request, RequestStatus + +import vllm_npu.envs as envs_ascend +from vllm_npu.distributed.utils import get_transfer_timeout_value +from vllm_npu.utils import AscendSocVersion, get_ascend_soc_version + +TORCH_DTYPE_TO_NPU_DTYPE = { + torch.half: llm_datadist.DataType.DT_FLOAT16, + torch.float16: llm_datadist.DataType.DT_FLOAT16, + torch.bfloat16: llm_datadist.DataType.DT_BF16, + torch.float: llm_datadist.DataType.DT_FLOAT, + torch.float32: llm_datadist.DataType.DT_FLOAT, + torch.int8: llm_datadist.DataType.DT_INT8, + torch.int64: llm_datadist.DataType.DT_INT64, + torch.int32: llm_datadist.DataType.DT_INT32 +} + + +class LLMDataDistCMgrEvent(Enum): + ReqForMetadata = 0 + ReqForFinished = 1 + + +class LLMDataDistCMgrAgentMetadata(msgspec.Struct): + super_pod_id: str + server_id: str + device_id: str + device_ip: str + super_device_id: str + cluster_id: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: str + engine_id: str + remote_tp_size: str + + +class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req(self, request_id: str, local_block_ids: list[int], + kv_transfer_params: dict[str, Any]): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_tp_size=kv_transfer_params["remote_tp_size"], + ) + + +class LLMDataDistCMgrConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[ + LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler( + vllm_config, self.engine_id) + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): 
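+        # Scheduler-role pass-through: the scheduler component records which
+        # requests still need their KV pulled from the remote prefill node.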
+ assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches( + self, + kv_caches: dict[ + str, # type: ignore[override] + Tuple[torch.Tensor]]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + LLMDataDistCMgrConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata, **kwargs) -> None: + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + +class LLMDataDistCMgrConnectorScheduler(): + + def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + self.local_ip = get_ip() + # Can not retrieve the parallel config since it is not initialized. + self.local_dp_rank = None + self.tp_size = None + if vllm_config.parallel_config.data_parallel_external_lb: + dp_rank_local = vllm_config.parallel_config.data_parallel_rank + else: + dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + + self.port = dp_rank_local * tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT + + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[str, float] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). 
+ """ + + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}" + ) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, + num_externel_tokens: int): + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}" + ) + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port", "remote_tp_size")): + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + logger.warning("" \ + f"Invalid KVTransferParams {params}, This request will be discard") + else: + assert num_externel_tokens == 0 + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = LLMDataDistCMgrConnectorMetadata() + + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params) + + meta.reqs_to_send = copy.deepcopy(self._reqs_need_send) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + self._reqs_need_send.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + + params = request.kv_transfer_params + logger.debug( + "LLMDataDistCMgrConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # note: NIXL transfer the full block only, but I don't see any reason to do that, so here + # we just transfer any data that computed from prefill node + # note: there might be some issue on this, check it if there is any unexpected result + computed_block_ids = block_ids + delay_free_blocks = len(computed_block_ids) > 0 + if delay_free_blocks: + logger.info("Delaying free of %d blocks for request %s", + len(computed_block_ids), request.request_id) + # Prefill request on remote. 
It will be read from D upon completion + self._reqs_need_send[request.request_id] = time.perf_counter( + ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.local_ip, + remote_port=self.port, + remote_tp_size=str( + self.vllm_config.parallel_config.tensor_parallel_size), + ) + + +class LLMDataDistCMgrConnectorWorker(): + """ + Implementation of Worker side methods + """ + + def __init__(self, vllm_config: VllmConfig): + assert vllm_config.kv_transfer_config is not None + logger.info("Initialize the LLMDataDistCMgrConnectorWorker") + # we assume the local node only contains dp and tp, and tp will not communicate inter-node. + # for any scenario beyond this scope, the functionality of this connector is not guaranteed. + self.local_rank_on_node = get_world_group().rank % ( + vllm_config.parallel_config.data_parallel_size_local * + vllm_config.parallel_config.tensor_parallel_size) + self.local_rank = get_world_group().local_rank + if vllm_config.parallel_config.data_parallel_external_lb: + self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank + else: + self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_rank = get_tp_group().rank_in_group + self.rank = get_world_group().rank + self.local_ip = get_ip() + self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config + self.local_agent_metadata: Optional[ + LLMDataDistCMgrAgentMetadata] = None + self.vllm_config = vllm_config + self.executor = ThreadPoolExecutor(1) + self.thread_lock = threading.Lock() + + self.llm_datadist_role = None + self.llm_datadist_remote_role = None + if self.kv_transfer_config.kv_role == "kv_producer": + self.llm_datadist_role = LLMRole.PROMPT + self.llm_datadist_remote_role = LLMRole.DECODER + elif self.kv_transfer_config.kv_role == "kv_consumer": + self.llm_datadist_role = LLMRole.DECODER + self.llm_datadist_remote_role = LLMRole.PROMPT + else: + raise RuntimeError( + f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}" + ) + + # linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"} + self.linked_cluster: dict[Any, Any] = {} + self.prefill_device_list: list[tuple[int, int]] = [] + self.decode_device_list: list[tuple[int, int]] = [] + global_rank_table = self.read_offline_rank_table() + self.local_agent_metadata = self.read_agent_metadata(global_rank_table) + self.llm_datadist = LLMDataDist(self.llm_datadist_role, + self.local_agent_metadata.cluster_id) + self.init_llm_datadist() + self.finished_reqs: set[str] = set() + self.soc_info = get_ascend_soc_version() + # Set hccl deterministic for model execute + os.environ["HCCL_DETERMINISTIC"] = "true" + self.done_receiving_counts: defaultdict[str, + set[int]] = defaultdict(set) + self.reqs_to_send: dict[str, float] = {} + + def listen_for_agent_metadata_req(self, event: threading.Event): + assert self.local_agent_metadata is not None + port = envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.tp_size + self.tp_rank + url = f"tcp://{envs_ascend.vllm_npu_LLMDD_RPC_IP}:{port}" + 
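+        # This ROUTER socket answers two request types from peer workers:
+        # ReqForMetadata (exchange agent metadata so the peer can set up the
+        # link) and ReqForFinished (the decode side notifying that a request's
+        # KV transfer is done and its blocks can be released).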
msg_encoder = msgspec.msgpack.Encoder() + msg_decoder = msgspec.msgpack.Decoder() + msg_to_send = msg_encoder.encode(self.local_agent_metadata) + logger.debug(f"Start to listen to address: {url}") + logger.debug( + f"The local agent metadata have {len(msg_to_send)} bytes here") + logger.info( + f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers" + ) + with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined] + event.set() + while True: + identity, _, msg = sock.recv_multipart() + event_msg, decode_msg = msg_decoder.decode(msg) + event_msg = LLMDataDistCMgrEvent(event_msg) + if event_msg == LLMDataDistCMgrEvent.ReqForMetadata: + if "cluster_id" in decode_msg: + decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg) + logger.info( + f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}" + ) + sock.send_multipart((identity, b"", msg_to_send)) + self.add_remote_agent(decode_msg) + else: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" + ) + elif event_msg == LLMDataDistCMgrEvent.ReqForFinished: + finished_req_id = decode_msg[0] + with self.thread_lock: + logger.debug( + f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" + ) + if finished_req_id in self.reqs_to_send: + self.finished_reqs.add(finished_req_id) + del self.reqs_to_send[finished_req_id] + sock.send_multipart( + (identity, b"", b"receiving decode finished")) + else: + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !" + ) + + def init_llm_datadist(self): + assert self.local_agent_metadata is not None + llm_config = LLMConfig() + llm_config.device_id = self.local_rank + llm_config.sync_kv_timeout = get_transfer_timeout_value() + llm_config.enable_switch_role = True + llm_config.enable_cache_manager = True + llm_config.enable_remote_cache_accessible = True + llm_config_options = llm_config.generate_options() + self.llm_datadist.init(llm_config_options) + self.cache_manager = self.llm_datadist.cache_manager + logger.info( + f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}" + ) + + def read_offline_rank_table(self): + assert ( + envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + ), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH" + rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + with open(rank_table_path, "r", encoding="utf-8") as f: + global_rank_table = json.load(f) + decode_device_list = global_rank_table["decode_device_list"] + for decode_device in decode_device_list: + server_id = decode_device["server_id"] + device_id = decode_device["device_id"] + self.decode_device_list.append((server_id, device_id)) + prefill_device_list = global_rank_table["prefill_device_list"] + for prefill_device in prefill_device_list: + server_id = prefill_device["server_id"] + device_id = prefill_device["device_id"] + self.prefill_device_list.append((server_id, device_id)) + + # global_rank_table = json.dumps(global_rank_table) + return global_rank_table + + @staticmethod + def _get_visible_devices() -> Callable[[str], bool]: + """ + Return a test function that check if the given device ID is visible. + i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id. 
+ """ + visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "") + if not visible_devices: + return lambda device_id: True + visible_device_list = visible_devices.split(",") + return lambda device_id: device_id in visible_device_list + + def read_agent_metadata(self, global_rank_table): + device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices() + devices_type_list = [] + agent_metadata = None + if self.llm_datadist_role == LLMRole.PROMPT: + devices_type_list.append("prefill_device_list") + elif self.llm_datadist_role == LLMRole.DECODER: + devices_type_list.append("decode_device_list") + else: + devices_type_list.append("prefill_device_list") + devices_type_list.append("decode_device_list") + for device_type in devices_type_list: + device_list = global_rank_table[device_type] + device_list = [ + d for d in device_list if d.get("server_id") == self.local_ip + and device_filter(d.get("device_id", "")) + ] + if len(device_list) <= self.tp_rank: + continue + device_info = device_list[self.tp_rank] + super_pod_id_ = device_info.get("super_pod_id", None) + server_id_ = device_info["server_id"] + device_id_ = device_info["device_id"] + device_ip_ = device_info["device_ip"] + super_device_id_ = device_info.get("super_device_id", None) + cluster_id_ = int(device_info["cluster_id"]) + agent_metadata = LLMDataDistCMgrAgentMetadata( + super_pod_id=super_pod_id_, + server_id=server_id_, + device_id=device_id_, + device_ip=device_ip_, + super_device_id=super_device_id_, + cluster_id=cluster_id_, + ) + assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table" + return agent_metadata + + def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + assert len(first_kv_cache_tuple) > 1 + assert self.local_agent_metadata is not None + kv_cache_dtype = first_kv_cache.dtype + self.use_mla: bool = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) and len( + first_kv_cache_tuple) == 2 + self.use_sparse: bool = len(first_kv_cache_tuple) == 3 + # MLA case. [2 (k_normed, k_pe), num_blocks, ...] + # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...] + # MHA case. [2 (k and v), num_blocks, ...] 
+ self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + + self.block_len = math.prod(block_shape) + self.cache_addr: list[int] = [] + alignment = 2 * 1024 * 1024 + if self.use_mla: + cache_k_normed_addr_list = [] + cache_k_pe_addr_list = [] + k_normed = None + k_pe = None + for cache_or_caches in kv_caches.values(): + assert len(cache_or_caches) > 1 + k_normed, k_pe = cache_or_caches[0], cache_or_caches[1] + cache_k_normed_addr_list.append(k_normed.data_ptr()) + cache_k_pe_addr_list.append(k_pe.data_ptr()) + self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list) + + cache_desc_k_normed = CacheDesc( + len(self.cache_addr[0]), [*k_normed.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_pe = CacheDesc( + len(self.cache_addr[1]), [*k_pe.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_key_k_normed = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=0) + cache_key_k_pe = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=1) + self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe) + self.cache_key = (cache_key_k_normed, cache_key_k_pe) + try: + cache_k_normed = self.cache_manager.register_blocks_cache( + self.cache_desc[0], self.cache_addr[0], self.cache_key[0]) + cache_k_pe = self.cache_manager.register_blocks_cache( + self.cache_desc[1], self.cache_addr[1], self.cache_key[1]) + self.cache = (cache_k_normed, cache_k_pe) + logger.info("LLMDataDistWorker: End of register Paged Cache.") + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + elif self.use_sparse: + cache_k_normed_addr_list = [] + cache_k_pe_addr_list = [] + cache_k_idx_addr_list = [] + k_normed = None + k_pe = None + k_idx = None + for cache_or_caches in kv_caches.values(): + assert len(cache_or_caches) > 1 + k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[ + 1], cache_or_caches[2] + cache_k_normed_addr_list.append(k_normed.data_ptr()) + cache_k_pe_addr_list.append(k_pe.data_ptr()) + cache_k_idx_addr_list.append(k_idx.data_ptr()) + self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list, + cache_k_idx_addr_list) + + cache_desc_k_normed = CacheDesc( + len(self.cache_addr[0]), [*k_normed.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_pe = CacheDesc( + len(self.cache_addr[1]), [*k_pe.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_idx = CacheDesc( + len(self.cache_addr[2]), [*k_idx.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_key_k_normed = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=0) + cache_key_k_pe = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=1) + cache_key_k_idx = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=2) + self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe, + cache_desc_k_idx) + self.cache_key = (cache_key_k_normed, cache_key_k_pe, + cache_key_k_idx) + try: + cache_k_normed = self.cache_manager.register_blocks_cache( + self.cache_desc[0], self.cache_addr[0], self.cache_key[0]) + cache_k_pe = self.cache_manager.register_blocks_cache( + self.cache_desc[1], self.cache_addr[1], self.cache_key[1]) + cache_k_idx = 
self.cache_manager.register_blocks_cache( + self.cache_desc[2], self.cache_addr[2], self.cache_key[2]) + self.cache = (cache_k_normed, cache_k_pe, cache_k_idx) + logger.info("LLMDataDistWorker: End of register Paged Cache.") + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + else: + for cache_or_caches in kv_caches.values(): + for cache in cache_or_caches: + base_addr = cache.data_ptr() + assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + self.cache_addr.append(base_addr) + # register paged kv cache into the llm_cache manager + self.cache_desc = CacheDesc( + len(self.cache_addr), [*cache.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + self.cache_key = BlocksCacheKey( + cluster_id=int(self.local_agent_metadata.cluster_id)) + logger.info( + f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}" + ) + try: + self.cache = self.cache_manager.register_blocks_cache( + self.cache_desc, self.cache_addr, self.cache_key) + logger.info( + "LLMDataDistCMgrConnectorWorker: End of register Paged Cache." + ) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + self.ready_event = threading.Event() + self.metadata_agent_listener_t = threading.Thread( + target=self.listen_for_agent_metadata_req, + args=(self.ready_event, ), + daemon=True, + name="metadata_agent_listener") + self.metadata_agent_listener_t.start() + self.ready_event.wait() + + def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata): + futures = [] + for req_id, meta in metadata.requests.items(): + logger.debug(f"Start to transmit {req_id}") + future = self.executor.submit( + self._read_blocks, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_ip=meta.remote_host, + remote_port=int(meta.remote_port), + remote_engine_id=meta.engine_id, + request_id=req_id, + remote_tp_size=meta.remote_tp_size, + ) + futures.append(future) + + def handle_exception(future): + if future.exception(): + logger.error(f"KV transfer task failed: {future.exception()}") + + for future in futures: + future.add_done_callback(handle_exception) + self.reqs_to_send.update(metadata.reqs_to_send) + + def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: + assert self.local_agent_metadata is not None + remote_cluster_id = metadata.cluster_id + if remote_cluster_id in self.linked_cluster: + logger.debug( + f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection" + ) + return remote_cluster_id + remote_super_pod_id = metadata.super_pod_id + remote_server_id = metadata.server_id + is_same_server = remote_server_id == self.local_agent_metadata.server_id + is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id + if self.llm_datadist_role == LLMRole.PROMPT: + prefill_metadata = self.local_agent_metadata + decode_metadata = metadata + else: + prefill_metadata = metadata + decode_metadata = self.local_agent_metadata + comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}" + 
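+        # Build a minimal two-rank link on the fly: the prefill device always
+        # takes rank 0 and the decode device rank 1, and the rank table below
+        # describes only these two devices (plus super-pod info on A3).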
cluster_rank_info = { + prefill_metadata.cluster_id: 0, + decode_metadata.cluster_id: 1 + } + rank_table = {} + rank_table["version"] = "1.2" + rank_table["server_count"] = "1" if is_same_server else "2" + rank_table["status"] = "completed" + + # generate server_list for rank table + rank_table["server_list"] = [] # type: ignore[assignment] + decode_server_device_info = None + prefill_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", prefill_metadata.device_id + ), ("device_ip", prefill_metadata.device_ip + ), ("super_device_id", + prefill_metadata.super_device_id), ("rank_id", "0")] + if v is not None + }], + "server_id": + prefill_metadata.server_id + } + if is_same_server: + prefill_server_device_info["device"].append( # type: ignore[attr-defined] + { + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }) + else: + decode_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }], + "server_id": + decode_metadata.server_id + } + rank_table["server_list"].append( # type: ignore[attr-defined] + prefill_server_device_info) + if decode_server_device_info is not None: + rank_table["server_list"].append( # type: ignore[attr-defined] + decode_server_device_info) + + if self.soc_info == AscendSocVersion.A3: + # generate super_pod_list for rank table + super_pod_list = [] + prefill_super_pod_info = { + "super_pod_id": prefill_metadata.super_pod_id, + "server_list": [{ + "server_id": prefill_metadata.server_id + }], + } + if is_same_pod and not is_same_server: + prefill_super_pod_info[ + "server_list"].append( # type: ignore[attr-defined] + {"server_id": decode_metadata.server_id}) + super_pod_list.append(prefill_super_pod_info) + if not is_same_pod: + decode_super_pod_id = { + "super_pod_id": decode_metadata.super_pod_id, + "server_list": [{ + "server_id": decode_metadata.server_id + }], + } + super_pod_list.append(decode_super_pod_id) + rank_table[ + "super_pod_list"] = super_pod_list # type: ignore[assignment] + logger.info( + f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}" + ) + logger.info(f"rank table \n{rank_table}") + logger.info(f"comm name: {comm_name}") + logger.info(f"cluster rank info: {cluster_rank_info}") + comm_id = self.llm_datadist.link(comm_name, cluster_rank_info, + json.dumps(rank_table)) + while True: + ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id) + if ret == llm_datadist.RegisterMemStatus.OK: + logger.info( + f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}" + ) + break + elif ret == llm_datadist.RegisterMemStatus.FAILED: + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}" + ) + time.sleep(1) + logger.info("Checking query_register_mem_status again") + self.linked_cluster.update({remote_cluster_id: comm_id}) + logger.info(f"cached linked cluster: {self.linked_cluster}") + logger.info( + f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !" + ) + return remote_cluster_id + + def remove_remote_agent(self, cluster_id: int): + if cluster_id not in self.linked_cluster: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: Warning! 
Can't remove remote client with cluster id {cluster_id} because it does not exist in the linked_cluster list"
+            )
+            return
+        comm_id = self.linked_cluster[cluster_id]
+        try:
+            self.llm_datadist.unlink(comm_id)
+            self.linked_cluster.pop(cluster_id)
+        except LLMException:
+            logger.error(
+                f"Removing remote client with cluster id {cluster_id} failed! The program won't terminate, but please carefully check your environment"
+            )
+        logger.info(
+            f"Successfully removed remote client with cluster id {cluster_id} !"
+        )
+
+    def connect_to_remote_agent(self, host: str, port: int) -> int:
+        url = f"tcp://{host}:{port}"
+        logger.debug(f"Querying metadata from url: {url}")
+        msg_encoder = msgspec.msgpack.Encoder()
+        msg_send = msg_encoder.encode(
+            [LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
+        with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
+            logger.info("Requesting remote metadata from socket ...")
+            sock.send(msg_send)
+            metadata_bytes = sock.recv()
+            decoder = msgspec.msgpack.Decoder()
+            metadata = decoder.decode(metadata_bytes)
+            metadata = LLMDataDistCMgrAgentMetadata(**metadata)
+            logger.info(f"received metadata: {metadata}")
+            cluster_id = self.add_remote_agent(metadata)
+            return cluster_id
+
+    def send_finish_to_remote(self, host: str, ports: list[int], request_id):
+        for port in ports:
+            url = f"tcp://{host}:{port}"
+            logger.debug(f"Sending finished to remote: {url}")
+            msg_encoder = msgspec.msgpack.Encoder()
+            msg_send = msg_encoder.encode(
+                [LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
+            with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
+                try:
+                    sock.send(msg_send)
+                    logger.debug(
+                        f"Request id {request_id} finished message sent to remote {url}"
+                    )
+                    _ = sock.recv()
+                except Exception as e:
+                    logger.error(
+                        f"Failed to send request_id {request_id} to prefill: {e}"
+                    )
+
+    def _read_blocks(
+        self,
+        local_block_ids: list[int],
+        remote_block_ids: list[int],
+        remote_ip: str,
+        remote_port: int,
+        remote_engine_id: str,
+        request_id: str,
+        remote_tp_size: str,
+    ):
+        # if remote_ip not in self.linked_cluster:
+        tp_offset = self.tp_rank % int(remote_tp_size)
+        remote_cluster_id = self.connect_to_remote_agent(
+            remote_ip, remote_port + tp_offset)
+        num_local_blocks = len(local_block_ids)
+        if num_local_blocks == 0:
+            return
+        num_remote_blocks = len(remote_block_ids)
+        assert num_local_blocks <= num_remote_blocks
+        if num_local_blocks < num_remote_blocks:
+            remote_block_ids = remote_block_ids[-num_local_blocks:]
+
+        logger.info(f"remote cluster id is: {remote_cluster_id}")
+        if self.use_mla:
+            remote_cache_key_k_normed = BlocksCacheKey(
+                cluster_id=remote_cluster_id, model_id=0)
+            remote_cache_key_k_pe = BlocksCacheKey(
+                cluster_id=remote_cluster_id, model_id=1)
+            logger.info("Try pull blocks from remote server")
+            try:
+                self.cache_manager.pull_blocks(
+                    remote_cache_key_k_normed,
+                    self.cache[0],  # type: ignore[has-type]
+                    remote_block_ids,
+                    local_block_ids)
+                self.cache_manager.pull_blocks(
+                    remote_cache_key_k_pe,
+                    self.cache[1],  # type: ignore[has-type]
+                    remote_block_ids,
+                    local_block_ids)
+            except (TypeError, ValueError):
+                raise RuntimeError(
+                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
+                )
+            except LLMException:
+                raise RuntimeError(
+                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to
increase the sync_kv_timeout config or checking your connect status" + ) + elif self.use_sparse: + remote_cache_key_k_normed = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=0) + remote_cache_key_k_pe = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=1) + remote_cache_key_k_idx = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=2) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key_k_normed, + self.cache[0], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_pe, + self.cache[1], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_idx, + self.cache[2], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + else: + remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key, + self.cache, # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + remote_ports = list( + range(remote_port + self.tp_rank, + remote_port + int(remote_tp_size), self.tp_size)) + self.send_finish_to_remote(remote_ip, remote_ports, request_id) + with self.thread_lock: + self.finished_reqs.add(request_id) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requuests.""" + now = time.perf_counter() + with self.thread_lock: + while self.reqs_to_send: + req_id, expires = next(iter(self.reqs_to_send.items())) + if now < expires: + break + logger.warning( + "Some requests in prefill node fail to receive KV Cache transfer done signal. " + "If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. 
" + ) + if req_id in self.reqs_to_send: + self.finished_reqs.add(req_id) + del self.reqs_to_send[req_id] + req_ids_to_ret = copy.deepcopy(self.finished_reqs) + self.finished_reqs.clear() + if self.llm_datadist_role == LLMRole.PROMPT: + return req_ids_to_ret, None + else: + return None, req_ids_to_ret + + +# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + """Context manager for a ZMQ socket""" + + ctx: Optional[zmq.Context] = None # type: ignore[name-defined] + try: + ctx = zmq.Context() # type: ignore[attr-defined] + + if socket_type == zmq.ROUTER: # type: ignore[attr-defined] + socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined] + socket.bind(addr) + elif socket_type == zmq.REQ: # type: ignore[attr-defined] + socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined] + socket.connect(addr) + else: + raise ValueError(f"Unexpected socket type: {socket_type}") + + yield socket + finally: + if ctx is not None: + ctx.destroy(linger=0) \ No newline at end of file diff --git a/vllm_npu/distributed/mooncake/__init__.py b/vllm_npu/distributed/mooncake/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/distributed/mooncake/config_data.py b/vllm_npu/distributed/mooncake/config_data.py new file mode 100644 index 0000000..745d911 --- /dev/null +++ b/vllm_npu/distributed/mooncake/config_data.py @@ -0,0 +1,449 @@ +import array +import hashlib +import json +import os +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.utils import cdiv, logger +from vllm.v1.core.sched.output import NewRequestData + + +@dataclass +class MooncakeEngineMetadata: + """name of the LLM model""" + + model_name: str + """ world size when running under a distributed setting """ + world_size: int + """ worker id when running under a distributed setting """ + worker_id: int + """ the format of kv tensors """ + kv_dtype: torch.dtype + """ the shape of kv tensors """ + """ (num_layer, 2, metadata.block_size, num_kv_head, head_size) """ + kv_shape: tuple[int, int, int, int, int] + block_size: int = 128 + """ whether use MLA""" + use_mla: bool = False + + +@dataclass(order=True) +class MooncakeEngineKey: + model_name: str + world_size: int + worker_id: int + chunk_hash: str + + def __hash__(self): + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + )) + + def to_string(self): + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}") + + def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: + """Split the key into multiple keys for each layer""" + keys = [] + for layer_id in range(num_layers): + keys.append( + LayerMooncakeEngineKey( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + layer_id, + )) + return keys + + def to_dict(self): + # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. 
+ return { + "__type__": "CacheEngineKey", + "model_name": self.model_name, + "world_size": self.world_size, + "worker_id": self.worker_id, + "chunk_hash": self.chunk_hash, + } + + @staticmethod + def from_dict(d): + return MooncakeEngineKey( + model_name=d["model_name"], + world_size=d["world_size"], + worker_id=d["worker_id"], + chunk_hash=d["chunk_hash"], + ) + + +@dataclass(order=True) +class LayerMooncakeEngineKey(MooncakeEngineKey): + """A key for the layer cache engine""" + + layer_id: int + + def __hash__(self): + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + self.layer_id, + )) + + def to_string(self): + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}") + + +class ChunkedTokenDatabase(): + + def __init__( + self, + metadata: MooncakeEngineMetadata, + ): + self.metadata = metadata + + def _make_key_by_hash(self, + chunk_hash: str, + layer_id: Optional[int] = None): + assert self.metadata is not None + return MooncakeEngineKey( + self.metadata.model_name, + self.metadata.world_size, + self.metadata.worker_id, + chunk_hash, + ) + + def _hash( + self, + tokens: Union[torch.Tensor, List[int]], + prefix_hash: str, + ) -> str: + # TODO: change it to a more efficient hash function + if isinstance(tokens, torch.Tensor): + tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() + elif isinstance(tokens, list): + tokens_bytes = array.array("I", tokens).tobytes() + return hashlib.sha256(prefix_hash.encode("ascii") + + tokens_bytes).hexdigest() + + def _chunk_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + ) -> Iterable[Union[torch.Tensor, List[int]]]: + """ + Chunk the tokens into chunks of size self.metadata.block_size. + + :param tokens: the input tokens, with shape [seq_len] + device: the target device after chunking + + :return: a generator of chunks of tokens, each with + shape [metadata.block_size] + """ + for i in range(0, len(tokens), self.metadata.block_size): + yield tokens[i:i + self.metadata.block_size] + + def _prefix_hash( + self, + token_chunks: Iterable[Union[torch.Tensor, List[int]]], + ) -> Iterable[str]: + prefix_hash = '' + for token_chunk in token_chunks: + prefix_hash = self._hash(token_chunk, prefix_hash) + yield prefix_hash + + def process_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + mask: Optional[torch.Tensor] = None, + ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]: + """Process the tokens and return the corresponding cache engine keys. + + :param Union[torch.Tensor, List[int]] tokens: The tokens to process. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched, + and the Falses will ALWAYS be at the PREFIX of the tensor. + + :param bool make_key: Whether to make the cache engine key or not. + If False, the hash value will be returned instead. + + :returns: A iterable of tuples with three elements. The first element + is the start index of the tokens for the key. The second element + is the end index of the tokens for the key. The third element is + the cache engine key (or hash) for the tokens. + + :raises: ValueError if the number of Falses in the mask is not a + multiple of the chunk size. 
+ """ + if mask is not None: + num_falses = mask.numel() - mask.long().sum().item() + else: + num_falses = 0 + + if num_falses % self.metadata.block_size != 0: + raise ValueError( + "The number of Falses in the mask is not a multiple of the chunk size." + ) + total_len = len(tokens) + + token_chunks = self._chunk_tokens(tokens) + prefix_hashes = self._prefix_hash(token_chunks) + + start_idx = 0 + for chunk_id, hash_val in enumerate(prefix_hashes): + start_idx = chunk_id * self.metadata.block_size + end_idx = min(start_idx + self.metadata.block_size, total_len) + if start_idx < num_falses: + continue + else: + yield start_idx, end_idx, self._make_key_by_hash(hash_val) + + +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in mooncake + mooncake_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + + +@dataclass +class SaveSpec: + # Skip already saved tokens + skip_leading_tokens: int + # Whether the scheduler allow us to save the tokens + can_save: bool + + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # The token ids that has been scheduled so far + token_ids: list[int] + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + # FIXME: need to check whether the block ids will be changed after + # preemption + allocated_block_ids: list[int] + + # The number of tokens that has been savd + num_saved_tokens: int = 0 + + @staticmethod + def from_new_request( + new_request: "NewRequestData", + num_tokens_to_compute: int, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. + + """ + # vLLM 0.9.0 update: request.block_ids changed from list[int] to + # list[list[int]] + # Need to check the type of request.block_ids + + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + unfolded_block_ids = new_request.block_ids[0].copy() + + return RequestTracker( + req_id=new_request.req_id, + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. 
+ copy(), + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: Union[tuple[list[int], ...], list[int]], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_ids.extend(new_token_ids) + + if len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_ids: torch.Tensor + + block_ids: list[int] + # # Slot mapping if exchange for block_id + # slot_mapping: torch.Tensor + # Skip save or not + save_spec: Optional[SaveSpec] = None + # load_spec + load_spec: Optional[LoadSpec] = None + + is_last_chunk: Optional[bool] = None + + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + load_spec: Optional[LoadSpec] = None, + skip_save: Optional[bool] = False, + is_last_chunk: Optional[bool] = None, + discard_partial_chunks: bool = True, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. + skip_save (bool): whether to skip the save operation. + discard_partial_chunks (bool): whether to discard partial chunks. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_ids = tracker.token_ids + input_token_len = len(input_token_ids) + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. 
number of unsaved tokens is not reached the chunk boundary + skip_leading_tokens = tracker.num_saved_tokens + chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * + block_size if discard_partial_chunks else 0) + # Calculate number of tokens to save based on discard_partial_chunks + # setting + num_tokens_to_save = ((input_token_len // block_size * block_size) + if discard_partial_chunks else input_token_len) + + skip_save = skip_save or num_tokens_to_save < chunk_boundary + if skip_save and load_spec is None: + return None + + # If we need to save, update the number of saved tokens + if not skip_save: + tracker.num_saved_tokens = num_tokens_to_save + save_spec = SaveSpec(skip_leading_tokens, not skip_save) + + # Calculate the token ids and slot mappings for load and save + # OPTIMIZATION: pre-allocate the buffer for token ids and block ids + token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save] + + # # For load operation: check whether the request is scheduled to load + if load_spec is not None and load_spec.can_load: + logger.debug( + "Scheduled to load %d tokens for request %s", + load_spec.mooncake_cached_tokens, + tracker.req_id, + ) + else: + # Do not load if not in `can_load` state + load_spec = None + logger.debug( + f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}" + ) + return ReqMeta( + req_id=tracker.req_id, + token_ids=token_ids, + block_ids=tracker.allocated_block_ids, + save_spec=save_spec, + load_spec=load_spec, + is_last_chunk=is_last_chunk, + ) + + +class MooncakeConnectorMetadata(KVConnectorMetadata): + + def __init__(self, unfinished_request_ids): + self.requests = [] + self.unfinished_request_ids = unfinished_request_ids + + def add_request(self, req_meta: ReqMeta) -> None: + """Add a request to the metadata. + + Args: + req_meta (ReqMeta): the request metadata. 
+ """ + self.requests.append(req_meta) + + +@dataclass +class LasyerMultiBlockReqMeta: + req_id: str + keys: List[LayerMooncakeEngineKey] + starts: List[int] + ends: list[int] + block_ids: list[int] + layer_id: int + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + use_ascend_direct: bool + + @staticmethod + def from_file(file_path: str) -> "MooncakeStoreConfig": + with open(file_path) as file: + config = json.load(file) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", 3355443200), + local_buffer_size=config.get("local_buffer_size", 1073741824), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address"), + use_ascend_direct=config.get("use_ascend_direct", False)) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + config_path = os.getenv("MOONCAKE_CONFIG_PATH") + if not config_path: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_path) \ No newline at end of file diff --git a/vllm_npu/distributed/mooncake/kv_transfer.py b/vllm_npu/distributed/mooncake/kv_transfer.py new file mode 100644 index 0000000..83df46e --- /dev/null +++ b/vllm_npu/distributed/mooncake/kv_transfer.py @@ -0,0 +1,293 @@ +import queue +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional + +import torch +from vllm.utils import logger + +from vllm_npu.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta) +from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore + + +class KVTransferThread(threading.Thread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, name: str): + super().__init__(daemon=True, name=name) + self.tp_rank = tp_rank + self.tp_size = tp_size + self.m_store = m_store + self.ready_event = ready_event + self.kv_caches_base_addr = local_kv_caches_base_addr + self.block_len = block_len + self.token_database = token_database + self.block_size = block_size + self.done_task_lock = threading.Lock() + # TODO(jianzs): find a better way to detect MLA. 
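# --- Editor's illustration (not part of the patch) --------------------------
# The heuristic right below treats two block lengths as MLA (the cache splits
# into a latent "nope" part and a rope part with different byte sizes), while
# a regular KV cache contributes a single block length. prepare_value(), a few
# lines further down in this class, then turns (start, end, block_ids) into
# raw transfer regions with simple pointer arithmetic. A standalone sketch of
# that math with invented sizes; it mirrors the original computation using
# integer division, and in the real code the block id comes from
# block_ids[start // block_size] rather than a literal.
def _block_region(base_addr: int, block_id: int, block_len: int,
                  block_size: int, start: int, end: int) -> tuple[int, int]:
    addr = base_addr + block_id * block_len            # start of this block's bytes
    length = block_len // block_size * (end - start)   # bytes covering the token slice
    return addr, length

# e.g. block_size=128 tokens, 2 KV heads * 64 head_dim * fp16 => 32 KiB/block;
# tokens 384..448 fall in the fourth allocated block (384 // 128 == 3).
addr, length = _block_region(base_addr=0x10000000, block_id=3,
                             block_len=128 * 2 * 64 * 2, block_size=128,
                             start=384, end=448)
assert addr == 0x10000000 + 3 * 32768
assert length == (32768 // 128) * 64
# -----------------------------------------------------------------------------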
+ self.use_mla = len(block_len) == 2 + + self.request_queue: queue.Queue[Any] = queue.Queue() + # TODO(jianzs): make this configurable + self.executor = ThreadPoolExecutor(max_workers=32) + self.finished_requests: set[str] = set() + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, block_id + + def prepare_value_layer(self, start: int, end: int, block_ids: list[int], + layer_id: int): + block_id = block_ids[start // self.block_size] + if self.use_mla: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[1] + length_k = int(self.block_len[0] / self.block_size * (end - start)) + length_v = int(self.block_len[1] / self.block_size * (end - start)) + size_list = [length_k, length_v] + else: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[0] + length = int(self.block_len[0] / self.block_size * (end - start)) + size_list = [length, length] + addr_list = [addr_k, addr_v] + return addr_list, size_list + + def add_request( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + is_last_chunk: Optional[bool] = None, + current_event: Optional[torch.npu.Event] = None, + ) -> torch.Tensor: + req = ({ + "req_id": req_id, + "tokens": tokens, + "block_ids": block_ids, + "mask": mask, + "is_last_chunk": is_last_chunk, + "current_event": current_event, + }) + self.request_queue.put(req) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. 
+ """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + self.finished_requests.clear() + return finished_requests + + def set_finished_request(self, req_id): + with self.done_task_lock: + self.finished_requests.add(req_id) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + pass + + +class KVCacheStoreSendingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheSendingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + is_last_chunk = req_meta["is_last_chunk"] + current_event = req_meta["current_event"] + if self.m_store.config.use_ascend_direct: + addr_list = [] + size_list = [] + key_list = [] + blockIds = [] + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, block_id = self.prepare_value( + start, end, block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + blockIds.append(block_id) + if key_list: + """ + Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. + This issue will be fixed in CANN version 8.5.rc1. + You can manually build the master branch of the project at https://gitcode.com/cann/hixl + to resolve this issue before the 8.5.RC1 release. 
+ """ + if current_event is not None: + current_event.synchronize() + self.m_store.put_batch(key_list, addr_list, size_list, blockIds) + else: + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) + if current_event is not None: + current_event.synchronize() + self.m_store.put(key, addr, size) + if is_last_chunk: + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreRecvingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreRecvingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + if self.m_store.config.use_ascend_direct: + addr_list = [] + size_list = [] + key_list = [] + blockIds = [] + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, block_id = self.prepare_value( + start, end, block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + blockIds.append(block_id) + self.m_store.get_batch(key_list, addr_list, size_list, blockIds) + else: + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) + self.m_store.get(key, addr, size) + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerSendingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + num_layers: int): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreLayerSendingThread") + self.final_layer_id = num_layers - 1 + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + for index, key in enumerate(req_meta.keys): + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) + self.m_store.put(key, addr, size) + if req_meta.layer_id == self.final_layer_id: + self.set_finished_request(req_meta.req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerRecvingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + get_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreLayerRecvingThread") + self.get_event = get_event + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + 
self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + for index, key in enumerate(req_meta.keys): + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) + self.m_store.get(key, addr, size) + self.request_queue.task_done() + self.get_event.set() diff --git a/vllm_npu/distributed/mooncake/mooncake_engine.py b/vllm_npu/distributed/mooncake/mooncake_engine.py new file mode 100644 index 0000000..05c1c1a --- /dev/null +++ b/vllm_npu/distributed/mooncake/mooncake_engine.py @@ -0,0 +1,639 @@ +# Standard +import math +import threading +import time +from typing import Generator, List, Optional, Union + +# Third Party +import torch +from vllm.config import VllmConfig +from vllm.utils import get_kv_cache_torch_dtype, logger + +from vllm_npu.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, + MooncakeEngineMetadata) +from vllm_npu.distributed.mooncake.kv_transfer import ( + KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, + KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) +from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore + + +class MooncakeEngine: + #The main class for the cache engine. + + def __init__( + self, + vllm_config: VllmConfig, + use_layerwize: bool, + ): + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.use_mla = False + if (hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla): + self.use_mla = True + self.use_layerwise = use_layerwize + self.tp_rank = parallel_config.rank + self.tp_size = parallel_config.tensor_parallel_size + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "load_async", False) + self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "register_buffer", False) + self.block_size = vllm_config.cache_config.block_size + self.current_layer = 0 + # self.use_mla = first_kv_cache_tuple[0].size( + # -1) != first_kv_cache_tuple[1].size(-1) + self.num_layers = model_config.get_num_layers(parallel_config) + self.block_size = vllm_config.cache_config.block_size + num_kv_head = model_config.get_num_kv_heads(parallel_config) + head_size = model_config.get_head_size() + kv_dtype = get_kv_cache_torch_dtype( + vllm_config.cache_config.cache_dtype, model_config.dtype) + self.hidden_dim_size = num_kv_head * head_size + if self.use_mla: + kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) + else: + kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, + head_size) + self.metadata = MooncakeEngineMetadata( + model_config.model, + parallel_config.world_size, + parallel_config.rank, + kv_dtype, + kv_shape, + self.block_size, + self.use_mla, + ) + + self.token_database = ChunkedTokenDatabase(self.metadata) + + self.m_store = Mooncakestore(parallel_config) + + self.kv_send_thread: Optional[KVTransferThread] = None + self.kv_recv_thread: Optional[KVTransferThread] = None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + + # TODO(tms): Find a more robust way to detect and handle MLA + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + 
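# --- Editor's illustration (not part of the patch) --------------------------
# How the block_len values computed just below come out for the two layouts.
# With MLA each layer registers two tensors (the latent "nope" cache and the
# rope cache), so there are two per-block byte lengths; a regular KV cache has
# a single one. The shapes and dtypes here are invented for illustration only.
import math
import torch

k_nope = torch.empty(4, 128, 1, 512, dtype=torch.bfloat16)  # [num_blocks, block_size, 1, latent_dim]
k_pe = torch.empty(4, 128, 1, 64, dtype=torch.bfloat16)     # [num_blocks, block_size, 1, rope_dim]
block_len_mla = [
    k_nope.element_size() * math.prod(k_nope.shape[-3:]),
    k_pe.element_size() * math.prod(k_pe.shape[-3:]),
]
assert block_len_mla == [128 * 512 * 2, 128 * 64 * 2]  # bytes per block, nope vs rope

k = torch.empty(4, 128, 8, 128, dtype=torch.bfloat16)       # [num_blocks, block_size, kv_heads, head_dim]
block_len_std = [k.element_size() * math.prod(k.shape[-3:])]
assert block_len_std == [128 * 8 * 128 * 2]
# -----------------------------------------------------------------------------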
self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + else: + # [num_block, block_size, num_head, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + + logger.info("Registering KV_Caches. use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) + + self.kv_caches = kv_caches + self.kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + if self.register_buffer: + region_len = self.num_blocks * self.block_len[i % 2] + self._register(base_addr, region_len) + else: + cache_list = [cache_or_caches + ] if self.use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + if self.register_buffer: + region_len = self.num_blocks * self.block_len[0] + self._register(base_addr, region_len) + + if self.use_layerwise: + self.get_event = threading.Event() + if self.kv_role in ['kv_producer', 'kv_both']: + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreLayerSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event_sending, + self.num_layers) + self.kv_send_thread.start() + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreLayerRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event, self.get_event) + self.kv_recv_thread.start() + ready_event.wait() + else: + if self.kv_role in ['kv_producer', 'kv_both']: + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event_sending) + self.kv_send_thread.start() + if self.load_async: + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event) + self.kv_recv_thread.start() + ready_event.wait() + + def _register(self, ptr, length): + logger.debug( + "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " + "block_lens=%s", ptr, length, self.num_blocks, self.block_len) + try: + self.m_store.register_buffer(ptr, length) + except Exception as e: + raise RuntimeError( + f"Mooncake memory registration failed. 
Error is: {e}") + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + self.current_layer = 0 + self.layerwise_retrievers = [] + for request in metadata.requests: + load_spec = request.load_spec + if load_spec is None or not load_spec.can_load: #load =0 + continue + tokens = request.token_ids + req_id = request.req_id + if (load_spec.mooncake_cached_tokens % self.block_size + != 0) and (load_spec.mooncake_cached_tokens + == tokens.shape[0] - 1): + tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1] + else: + tokens = tokens[:request.load_spec.mooncake_cached_tokens] + masked_token_count = (request.load_spec.vllm_cached_tokens // + self.block_size * self.block_size) + token_mask = torch.ones_like(tokens, dtype=torch.bool) + token_mask[:masked_token_count] = False + if self.use_layerwise: + layerwise_retriever = self.retrieve_layer( + req_id, + tokens, + request.block_ids, + token_mask, + ) + next(layerwise_retriever) # first layer load + self.layerwise_retrievers.append(layerwise_retriever) + else: + if self.load_async: + self.kv_recv_thread.add_request( # type: ignore[union-attr] + req_id, + tokens, + request.block_ids, + token_mask, + ) + else: + if self.m_store.config.use_ascend_direct: + addr_list = [] + size_list = [] + key_list = [] + blockIds = [] + for start, end, key in self.token_database.process_tokens( + tokens, token_mask): + addr, size, block_id = self.prepare_value( + start, end, request.block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + blockIds.append(block_id) + self.m_store.get_batch(key_list, addr_list, size_list, + blockIds) + else: + for start, end, key in self.token_database.process_tokens( + tokens, token_mask): + addr, size, _ = self.prepare_value( + start, end, request.block_ids) + self.m_store.get(key, addr, size) + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, block_id + + def wait_for_layer_load(self) -> None: + """MooncakeConnector does not do layerwise saving.""" + for layerwise_retriever in self.layerwise_retrievers: + ret_token_mask = next(layerwise_retriever) + if self.current_layer == self.num_layers - 1: + assert ret_token_mask is not None + num_retrieved_tokens = ret_token_mask.sum().item() + logger.info(f"Retrieved {num_retrieved_tokens} tokens") + + def save_kv_layer(self, + connector_metadata: MooncakeConnectorMetadata) -> None: + """MooncakeConnector does not save explicitly.""" + if self.current_layer == 0: + self.layerwise_storers = [] + current_event = None + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + current_event = torch.npu.Event() + current_event.record() + break + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + # TODO: whether need to remov saveThread + # no lookup, skipmask + skip_leading_tokens = max( + 
self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) + continue # skip this request + + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + layerwise_storer = self.store_layer( + req_id, + token_ids, + mask=store_mask, + block_ids=request.block_ids, + ) + self.layerwise_storers.append(layerwise_storer) + for layerwise_storer in self.layerwise_storers: + try: + next(layerwise_storer) + except Exception: + raise + self.current_layer = self.current_layer + 1 + + def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): + """MooncakeConnector does not save explicitly.""" + current_event = None + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + current_event = torch.npu.Event() + current_event.record() + break + + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + skip_leading_tokens = max( + self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) + continue # skip this request + + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + self.kv_send_thread.add_request( # type: ignore[union-attr] + req_id, + token_ids, + request.block_ids, + store_mask, + request.is_last_chunk, + current_event, + ) + + def retrieve_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[Optional[torch.Tensor], None, None]: + """ + Retrieve the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. + + :param **kwargs: The additional arguments for the KV transfer which + will be passed into the npu_transfer. + + return: A generator that yields Optional[torch.Tensor]. The tensor will + be the boolean mask indicating which tokens are retrieved and will + only be returned in the last iteration. 
+ """ + + if mask is not None: + num_required_tokens = torch.sum(mask).item() + else: + num_required_tokens = len(tokens) + + ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") + + starts = [] + ends = [] + keys = [] + first_flag = True + for start, end, key in self.token_database.process_tokens( + tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) + ret_mask[start:end] = True + + if keys: + # Transpose the keys into layer major format + keys = [list(row) for row in zip(*keys)] # [num_layer,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + if not first_flag: + is_finish = self.get_event.wait(timeout=3) #try---cache + if not is_finish: + logger.info("Layerwise get failed") + self.get_event.clear() + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) + self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] + first_flag = False + yield None + else: + # If no cache are found, we still need to yield to avoid + # `StopIteration` + for layer_id in range(self.num_layers): + yield None + + retrieved_tokens = torch.sum(ret_mask) + logger.debug(f"Retrieved {retrieved_tokens} " + f"out of {num_required_tokens} " + f"out of total {len(tokens)} tokens") + + yield ret_mask + + def store_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[None, None, None]: + """ + Store the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. + + :param **kwargs: The additional arguments for the storage backend which + will be passed into the gpu_connector. + + return: A generator that yields None. In the first iteration, the + generator allocates the memory objects for all layers and moves + the KV cache of the first layer from GPU to CPU. In the next + iterations, it moves the KV cache of layer i from GPU to the memory + objects (on CPU) and puts the memory objects of layer i-1 to the + storage backends. In the last iteration, it puts the memory objects + of the last layer to the storage backends. + """ + + if mask is not None: + num_stored_tokens = torch.sum(mask).item() + else: + num_stored_tokens = len(tokens) + + starts = [] + ends = [] + keys = [] + for start, end, key in self.token_database.process_tokens( + tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) #[block_num,layer_num] + + if keys: + keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) + self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] + yield + else: + for layer_id in range(self.num_layers): + yield + logger.debug( + f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") + + def get_finished(self) -> tuple[set[str], set[str]]: + done_sending = ( + self.kv_send_thread. 
+ get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role in ['kv_producer', 'kv_both'] else set()) + + done_recving = ( + self.kv_recv_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.load_async else set()) + + logger.debug( + "Number of completed KV cache send requests: %d, receive " + "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), + self.tp_rank) + return done_sending, done_recving + + def wait_layer_transfer_finish(self): + time.sleep(10) + pass + + def lookup( + self, + tokens: Union[torch.Tensor, List[int]], + use_layerwise: bool, + ) -> int: + """ + Checks the existence of KV cache of the tokens from the cache engine. + :param tokens: the input tokens, with shape [seq_len] + :return: An int indicating how many prefix tokens are cached. + """ + end = 0 + keys = [] + try: + if use_layerwise: + for start, end, key in self.token_database.process_tokens( + tokens): + keys_multi_layer = key.split_layers(self.num_layers) + for item in keys_multi_layer: + keys.append(item.to_string()) + # batch is_exists + ress = self.m_store.batch_exists(keys) + res = 1 + for value in ress: + if value != 1: + res = 0 + break + if res == 1: + continue + else: + return start + else: + starts = [] + for start, end, key in self.token_database.process_tokens( + tokens): + keys.append(key.to_string()) + starts.append(start) + res = self.m_store.batch_exists( + keys) # type: ignore[assignment] + for index, value in enumerate(res): # type: ignore[arg-type] + if value != 1: + return starts[index] + # all tokens where found, return the maximal end + except Exception as e: + logger.error(f"Remote connection failed in contains: {e}") + return start + return end + + def lookup_scheduler( + self, + tokens: Union[torch.Tensor, List[int]], + use_layerwise: bool, + ) -> int: + """ + Checks the existence of KV cache of the tokens from the cache engine. + :param tokens: the input tokens, with shape [seq_len] + :return: An int indicating how many prefix tokens are cached. 
+ """ + end = 0 + keys = [] + try: + if use_layerwise: + for start, end, key in self.token_database.process_tokens( + tokens): + keys_multi_layer = key.split_layers(self.num_layers) + for item in keys_multi_layer: + keys.append(item.to_string()) + # batch is_exists + ress = self.m_store.batch_exists(keys) + res = 1 + for value in ress: + if value != 1: + res = 0 + break + if res == 1: + continue + else: + return start + else: + starts = [] + for start, end, key in self.token_database.process_tokens( + tokens): + keys.append(key.to_string()) + starts.append(start) + multi_tp_keys = keys[:] + for i in range(1, self.tp_size): + for item in keys: + new_str = item.replace( # type: ignore[attr-defined] + "@0", f"@{i}", 1) + multi_tp_keys.append(new_str) + res = self.m_store.batch_exists( + multi_tp_keys) # type: ignore[assignment] + num_block = len(keys) + multi_tp_values = [ + res[i * num_block:(i + 1) * + num_block] # type: ignore[index] + for i in range(self.tp_size) + ] + index = self.find_min_first_non_one_index(multi_tp_values) + if index != -1: + return starts[index] + # all tokens where found, return the maximal end + except Exception as e: + logger.error(f"Remote connection failed in contains: {e}") + return start + return end + + def find_min_first_non_one_index(self, arr): + try: + return min(idx for row in arr for idx, val in enumerate(row) + if val != 1) + except ValueError: + return -1 + + def close(self) -> None: + """Close the cache engine and free all the resources""" + self.m_store.close() diff --git a/vllm_npu/distributed/mooncake/mooncake_store.py b/vllm_npu/distributed/mooncake/mooncake_store.py new file mode 100644 index 0000000..1aec15a --- /dev/null +++ b/vllm_npu/distributed/mooncake/mooncake_store.py @@ -0,0 +1,126 @@ +# Standard +import os + +# Third Party +from mooncake.store import ReplicateConfig # type: ignore +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.utils import get_ip, logger + +from vllm_npu.distributed.mooncake.config_data import MooncakeEngineKey +from vllm_npu.distributed.mooncake.transfer_engine import get_global_te + +from .config_data import MooncakeStoreConfig + +METADATA_BYTES_LEN = 24 +BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) + + +class Mooncakestore(): + + def __init__(self, parallel_config: ParallelConfig): + try: + from mooncake.store import MooncakeDistributedStore # type: ignore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + tp_rank = get_tensor_model_parallel_rank() + tp_size = parallel_config.tensor_parallel_size + dp_rank = parallel_config.data_parallel_rank_local + all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) + if not all_device_ids: + device_ids_list = list( + range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) + else: + device_ids_list = list(map(int, all_device_ids.split(','))) + assert len(device_ids_list) > tp_rank + device_id = device_ids_list[tp_rank] + self.config = MooncakeStoreConfig.load_from_env() + self.store = MooncakeDistributedStore() + if self.config.protocol == "ascend" and not self.config.use_ascend_direct: + local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \ + ":npu_" + str(device_id) + ret = self.store.setup(local_hostname, self.config.metadata_server, + self.config.global_segment_size, + 
self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address) + else: + local_hostname = get_ip() + transfer_engine = get_global_te(local_hostname, device_name=None) + self.local_seg = local_hostname + ":" + str( + transfer_engine.get_rpc_port()) + ret = self.store.setup(self.local_seg, self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + transfer_engine.get_engine()) + if ret != 0: + msg = "Initialize mooncake failed." + logger.error(msg) + raise RuntimeError(msg) + + def exists(self, key: MooncakeEngineKey) -> bool: + return self.store.is_exist(key.to_string()) == 1 + + def batch_exists(self, keys: list[str]) -> list[int]: + return self.store.batch_is_exist(keys) + + def register_buffer(self, ptr, length): + return self.store.register_buffer(ptr, length) + + def get_batch(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]], block_ids: list[int]): + try: + res = self.store.batch_get_into_multi_buffers( + keys, addrs, sizes, True) + for value in res: + if value < 0: + logger.error(f"Failed to get key {keys},res:{res}") + except Exception as e: + logger.error(f"Failed to get key {keys}. {e}") + + def put_batch(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]], block_ids: list[int]): + try: + config = ReplicateConfig() + config.preferred_segment = self.local_seg + config.prefer_alloc_in_same_node = True + res = self.store.batch_put_from_multi_buffers( + keys, addrs, sizes, config) + for value in res: + if value < 0: + logger.error(f"Failed to put key {keys},res:{res}") + except Exception as e: + logger.error(f"Failed to put key {keys},error:{e}") + + def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + expect_res = sum(size) + key_str = key.to_string() + try: + res = self.store.batch_get_into_ascend(key_str, addr, size) + if res[0] != expect_res: + logger.error(f"Failed to get key: [{key_str}] .") + except Exception: + logger.error(f"Failed to get key: [{key_str}] .") + return res + + def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + key_str = key.to_string() + try: + ret = self.store.batch_put_from_ascend(key_str, addr, size) + if ret[0] != 0: + logger.error(f"Failed to put key {key_str}.") + except Exception: + logger.error(f"Failed to put key {key_str}.") + + return ret + + def close(self): + self.store.close() + logger.info("Closed the mooncake store connection") diff --git a/vllm_npu/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_npu/distributed/mooncake/mooncake_store_connector_v1.py new file mode 100644 index 0000000..7733826 --- /dev/null +++ b/vllm_npu/distributed/mooncake/mooncake_store_connector_v1.py @@ -0,0 +1,494 @@ +import threading +from typing import Any, Optional + +import torch +import vllm.envs as envs +import zmq +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.forward_context import ForwardContext +from vllm.utils import logger, make_zmq_socket +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + +from 
vllm_npu.distributed.mooncake.config_data import ( + LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker) +from vllm_npu.distributed.mooncake.mooncake_engine import MooncakeEngine + + +class MooncakeConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.kv_role = vllm_config.kv_transfer_config.kv_role + + self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "use_layerwise", False) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + self.sended_but_unfinished_reqs: set[str] = set() + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = MooncakeStoreConnectorV1Scheduler( + vllm_config, self.use_layerwise) + else: + self.connector_worker = MooncakeEngine( + vllm_config, + self.use_layerwise, + ) + + assert self.connector_worker is not None + if vllm_config.parallel_config.rank == 0 and self.kv_role != "kv_consumer": + self.lookup_server = MooncakeLookupServer( + self.connector_worker, vllm_config, self.use_layerwise) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._get_connector_metadata(), + MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._get_connector_metadata()) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeStoreConnector does not do layerwise saving.""" + if not self.use_layerwise: + return + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeStoreConnector does not save explicitly.""" + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + self.connector_worker.save_kv_layer(self._get_connector_metadata()) + + def wait_for_save(self): + """MooncakeStoreConnector does not save 
explicitly.""" + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + + if self.use_layerwise: + self.connector_worker.wait_layer_transfer_finish() + return + + self.connector_worker.wait_for_save(self._get_connector_metadata()) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + meta = self._get_connector_metadata() + done_sending, done_recving = self.connector_worker.get_finished() + sended_and_finished: set[str] = set() + for item in list(self.sended_but_unfinished_reqs): + if item not in meta.unfinished_request_ids: + sended_and_finished.add(item) + self.sended_but_unfinished_reqs.remove(item) + for item in done_sending: + if item in meta.unfinished_request_ids: + self.sended_but_unfinished_reqs.add(item) + else: + sended_and_finished.add(item) + + return sended_and_finished, done_recving + + +def get_zmq_rpc_path_mooncake( + vllm_config: Optional["VllmConfig"] = None, ) -> str: + base_url = envs.VLLM_RPC_BASE_PATH + # Default to 0 if not configured + rpc_port = 0 + if vllm_config is not None: + rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( + "mooncake_rpc_port", 0) + logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) + return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" + + +class MooncakeStoreConnectorV1Scheduler: + + def __init__(self, vllm_config: "VllmConfig", use_layerwise): + self.use_layerwise = use_layerwise + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.client = MooncakeLookupClient( + vllm_config) if self.kv_role != "kv_consumer" else None + self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "consumer_is_to_load", False) + self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "load_async", False) + # request_id -> (vllm cached tokes, mooncake cached tokens) + self.load_specs: dict[str, LoadSpec] = {} + self._block_size = vllm_config.cache_config.block_size + # request_id -> full_token_ids + self._request_trackers: dict[str, RequestTracker] = {} + # Whether to discard partial chunks + self._discard_partial_chunks = ( + vllm_config.kv_transfer_config.get_from_extra_config( + "discard_partial_chunks", True)) + self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} + self._unfinished_request_ids: set[str] = set() + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Check for external KV cache hit. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. 
+        """
+        if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
+            return 0, False
+
+        if self._discard_partial_chunks:
+            token_block_end = len(request.prompt_token_ids
+                                  ) // self._block_size * self._block_size
+            token_ids = torch.tensor(
+                request.prompt_token_ids[:token_block_end])
+        else:
+            token_ids = torch.tensor(request.prompt_token_ids)
+
+        num_external_hit_tokens = self.client.lookup(  # type: ignore[union-attr]
+            token_ids)
+
+        if num_external_hit_tokens == request.num_tokens:
+            num_external_hit_tokens -= 1
+
+        need_to_allocate = num_external_hit_tokens - num_computed_tokens
+
+        logger.info(
+            "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
+            request.request_id,
+            request.num_tokens,
+            num_external_hit_tokens,
+            need_to_allocate,
+        )
+
+        if need_to_allocate <= 0:
+            return 0, False
+
+        self.load_specs[request.request_id] = LoadSpec(
+            vllm_cached_tokens=num_computed_tokens,
+            mooncake_cached_tokens=num_external_hit_tokens,
+            can_load=False,
+        )
+
+        return need_to_allocate, self.load_async
+
+    def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
+                                 num_external_tokens: int):
+        """
+        Update KVConnector state after temporary buffer alloc.
+
+        If the KV cache manager allocated blocks for this request, mark
+        its LoadSpec as loadable so the worker side will pull the matched
+        tokens from Mooncake.
+        """
+        local_block_ids = []
+        if num_external_tokens > 0:
+            local_block_ids = blocks.get_block_ids()[0]
+
+        self._unfinished_requests[request.request_id] = (request,
+                                                         local_block_ids)
+        self._unfinished_request_ids.add(request.request_id)
+        if request.request_id not in self.load_specs:
+            # No KV tokens from external KV cache, return
+            return
+
+        if num_external_tokens == 0:
+            # No need to load anything
+            self.load_specs[request.request_id].can_load = False
+            return
+
+        assert (
+            num_external_tokens > 0 and num_external_tokens
+            == self.load_specs[request.request_id].mooncake_cached_tokens -
+            self.load_specs[request.request_id].vllm_cached_tokens
+        ), (f"Mismatch in number of tokens: {num_external_tokens} vs "
+            f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
+            f"{self.load_specs[request.request_id].vllm_cached_tokens}"
+            f" for request {request.request_id}")
+
+        self.load_specs[request.request_id].can_load = True
+
+    def build_connector_meta(
+            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
+        """Attach the connector metadata to the request object.
+
+        This function should NOT modify other fields in the scheduler_output
+        except the `kv_connector_metadata` field.
+        Also, calling this function will reset the state of the connector.
+
+        Args:
+            scheduler_output (SchedulerOutput): the scheduler output object.
+ """ + + force_skip_save = self.kv_role == "kv_consumer" + + for finished_req_id in scheduler_output.finished_req_ids: + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + self._unfinished_request_ids.discard(finished_req_id) + + meta = MooncakeConnectorMetadata(self._unfinished_request_ids) + + for request in scheduler_output.scheduled_new_reqs: + # Right now, we only load KV for new requests + load_spec = self.load_specs.pop(request.req_id, None) + num_tokens_to_compute = ( + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id]) + request_tracker = RequestTracker.from_new_request( + request, num_tokens_to_compute) + self._request_trackers[request.req_id] = request_tracker + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else len( + request.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + cached_reqs = scheduler_output.scheduled_cached_reqs + if isinstance(cached_reqs, list) and not force_skip_save: + for i, req in enumerate(cached_reqs): + request_tracker = self._request_trackers[req.req_id] + request_tracker.update(req.new_token_ids, req.new_block_ids) + last_chunk_tokens_num = ((len(req.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(req.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + elif not force_skip_save: + for i, req_id in enumerate(cached_reqs.req_ids): + request_tracker = self._request_trackers[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + req_tuple = self._unfinished_requests.get(req_id) + if req_tuple: + request = req_tuple[0] + num_current_tokens = len(request_tracker.token_ids) + new_token_ids = request.all_token_ids[ + num_current_tokens:num_current_tokens + num_new_tokens] + else: + raise ValueError( + f"Request {req_id} is not in _unfinished_requests, " + f"but it is scheduled to be cached") + new_block_ids = cached_reqs.new_block_ids[i] + if not new_block_ids: + continue + request_tracker.update(new_token_ids, new_block_ids) + # decode not save + if len(request_tracker.token_ids) > len( + request.prompt_token_ids): + continue + + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(request.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + request_ids = [ + req.req_id for req in scheduler_output.scheduled_new_reqs + ] + for request_id, (request, + block_ids) in self._unfinished_requests.items(): + if request_id not in request_ids and request_id not in 
cached_reqs.req_ids: + load_spec = self.load_specs.pop(request_id, None) + if not load_spec: + continue + num_tokens_to_compute = load_spec.mooncake_cached_tokens + if (num_tokens_to_compute % self._block_size + != 0) and (num_tokens_to_compute + == len(request.prompt_token_ids) - 1): + num_tokens_to_compute = num_tokens_to_compute + 1 + request_tracker = RequestTracker( + req_id=request_id, + token_ids=request.prompt_token_ids[:num_tokens_to_compute]. + copy(), + allocated_block_ids=block_ids, + num_saved_tokens=0, + ) + + self._request_trackers[request_id] = request_tracker + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=None, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + if self.kv_role == "kv_consumer": + return False, None + tracker = self._request_trackers.get(request.request_id) + if tracker is not None and tracker.num_saved_tokens <= 0: + return False, None + delay_free_blocks = len(block_ids) > 0 + if delay_free_blocks: + logger.info("Delaying free of %d blocks for request %s", + len(block_ids), request.request_id) + return delay_free_blocks, None + + +class MooncakeLookupClient: + + def __init__(self, vllm_config: "VllmConfig"): + self.encoder = MsgpackEncoder() + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REQ, # type: ignore[attr-defined] + bind=False, + ) + + def lookup(self, token_ids: torch.Tensor) -> int: + request = self.encoder.encode(token_ids) + self.socket.send_multipart(request, copy=False) + resp = self.socket.recv() + result = int.from_bytes(resp, "big") + return result + + def close(self): + self.socket.close(linger=0) + + +class MooncakeLookupServer: + + def __init__( + self, + mooncake_engine: MooncakeEngine, + vllm_config: "VllmConfig", + use_layerwise: bool, + ): + self.decoder = MsgpackDecoder(torch.Tensor) + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REP, # type: ignore[attr-defined] + bind=True, + ) + + self.mooncake_engine = mooncake_engine + self.running = True + + def process_request(): + while self.running: + frames = self.socket.recv_multipart(copy=False) + token_ids = self.decoder.decode(frames) + result = self.mooncake_engine.lookup_scheduler( + token_ids, use_layerwise) + response = result.to_bytes(4, "big") + self.socket.send(response) + + self.thread = threading.Thread(target=process_request, daemon=True) + self.thread.start() + + def close(self): + self.socket.close(linger=0) + # TODO: close the thread! 
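The store connector above is selected and tuned entirely through vLLM's KVTransferConfig; every key it reads (use_layerwise, mooncake_rpc_port, consumer_is_to_load, load_async, discard_partial_chunks) comes from kv_connector_extra_config. A minimal producer-side sketch, assuming the connector is registered under the name "MooncakeConnectorStoreV1" (the actual registration name is defined by register_connector(), which is not part of this diff):

    from vllm.config import KVTransferConfig

    kv_cfg = KVTransferConfig(
        kv_connector="MooncakeConnectorStoreV1",  # assumed registration name
        kv_role="kv_producer",                    # decode side would use "kv_consumer"
        kv_connector_extra_config={
            "use_layerwise": False,          # layer-by-layer save/load path
            "mooncake_rpc_port": 0,          # used by get_zmq_rpc_path_mooncake()
            "consumer_is_to_load": False,    # allow consumers to load from the store
            "load_async": False,             # async flag returned by get_num_new_matched_tokens
            "discard_partial_chunks": True,  # block-aligned lookup and save
        },
    )

The decode-side instance would use the same extra config with kv_role set to "kv_consumer".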
diff --git a/vllm_npu/distributed/mooncake/transfer_engine.py b/vllm_npu/distributed/mooncake/transfer_engine.py new file mode 100644 index 0000000..d4e172b --- /dev/null +++ b/vllm_npu/distributed/mooncake/transfer_engine.py @@ -0,0 +1,38 @@ +import ipaddress +import threading +from typing import Optional + +from mooncake.engine import TransferEngine # type: ignore + +_global_te = None +_global_te_lock = threading.Lock() + + +def get_global_te(hostname: str, device_name: Optional[str]): + try: + ip = ipaddress.ip_address(hostname) + if isinstance(ip, ipaddress.IPv6Address): + raise RuntimeError( + "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6." + ) + except ValueError: + pass + + global _global_te + if _global_te is None: + with _global_te_lock: + # Double-Checked Locking + if _global_te is None: + if TransferEngine is None: + raise RuntimeError("mooncake is not available") + transfer_engine = TransferEngine() + device_name = device_name if device_name is not None else "" + ret_value = transfer_engine.initialize(hostname, + "P2PHANDSHAKE", + "ascend", device_name) + if ret_value != 0: + raise RuntimeError( + f"TransferEngine initialization failed with ret_value: {ret_value}" + ) + _global_te = transfer_engine + return _global_te diff --git a/vllm_npu/distributed/mooncake_connector.py b/vllm_npu/distributed/mooncake_connector.py new file mode 100644 index 0000000..4f7eb5c --- /dev/null +++ b/vllm_npu/distributed/mooncake_connector.py @@ -0,0 +1,1263 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import hashlib +import math +import os +import queue +import random +import struct +import threading +import time +from collections import defaultdict, deque +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple + +import msgspec +import numpy as np +import numpy.typing as npt +import torch +import torch_npu +import zmq +from mooncake.engine import TransferEngine # type: ignore +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, + get_tp_group) +from vllm.utils import get_ip, logger, make_zmq_path, make_zmq_socket +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config, init_ascend_config +from vllm_npu.distributed.mooncake.transfer_engine import get_global_te +from vllm_npu.distributed.utils import get_transfer_timeout_value + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +GET_META_MSG = b"get_meta_msg" +DONE_RECVING_MSG = b"done_recving_msg" + + +class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): + engine_id: str + te_rpc_port: int + kv_caches_base_addr: list[int] + num_blocks: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + + +class KVCacheTaskTracker: + + def __init__(self): + super().__init__() + + self.done_task_lock = threading.Lock() + 
self.finished_requests: set[str] = set() + # Only used in prefill node. Tracks requests whose kv blocks freeing is + # intentionally delayed. Each entry is a tuple of (request_id, + # timestamp). If a request remains in this queue for too long, it will + # be force-freed. + self.record_finished_requests: set[str] = set() + self.delayed_free_requests: OrderedDict[str, float] = OrderedDict() + + def add_not_transfer_request(self, request_id: str): + with self.done_task_lock: + self.finished_requests.add(request_id) + + def update_done_task_count(self, request_id: str): + with self.done_task_lock: + self.finished_requests.add(request_id) + if request_id in self.delayed_free_requests: + self._remove_delayed_requests(request_id) + else: + self.record_finished_requests.add(request_id) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + expired_requests = self._retrieve_expired_requests() + finished_requests.update(expired_requests) + self.finished_requests.clear() + return finished_requests + + def add_delayed_request(self, request_id: str, delay_start_time: float): + """Add a delayed free request.""" + with self.done_task_lock: + if request_id not in self.record_finished_requests: + self.delayed_free_requests[request_id] = delay_start_time + else: + self.record_finished_requests.discard(request_id) + + def _retrieve_expired_requests(self): + """Retrieve all expired delayed requests.""" + expired_requests: set[str] = set() + # Free delayed requests if they exceed the timeout + current_time = time.time() + while self.delayed_free_requests: + request_id = next(iter(self.delayed_free_requests)) + delay_start_time = self.delayed_free_requests[request_id] + if (current_time - delay_start_time + > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): + self.delayed_free_requests.popitem(last=False) + expired_requests.add(request_id) + logger.info("Force freed request: %s", request_id) + else: + break + return expired_requests + + def _remove_delayed_requests(self, request_id: str): + """Remove all delayed free requests matching the given request_id.""" + self.delayed_free_requests.pop(request_id) + + +class KVCacheSendingThread(threading.Thread): + + def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str, + side_channel_host: str, side_channel_port: int, + metadata: MooncakeAgentMetadata, ready_event: threading.Event, + kv_caches: dict[str, Any]): + super().__init__(daemon=True, name="KVCacheSendingThread") + self.tp_rank = tp_rank + self.decode_tp_size = decode_tp_size + self.local_engine_id = local_engine_id + self.side_channel_host = side_channel_host + self.side_channel_port = side_channel_port + self.metadata = metadata + self.ready_event = ready_event + self.kv_caches = kv_caches + + self.task_tracker = KVCacheTaskTracker() + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. 
+ """ + return self.task_tracker.get_and_clear_finished_requests() + + def add_not_transfer_request(self, request_id: str): + self.task_tracker.add_not_transfer_request(request_id) + + def add_delayed_request(self, request_id: str, delay_start_time: float): + return self.task_tracker.add_delayed_request(request_id, + delay_start_time) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(self.metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes", + str(size_in_bytes)) + + # Listen for new requests for metadata. + # NOTE(rob): we need each rank to have a unique port. This hack to keeps + # us moving. We will switch when moving to etcd or where we have a + # single ZMQ socket in the scheduler. + handshake_port = self.side_channel_port + self.tp_rank + path = make_zmq_path("tcp", self.side_channel_host, handshake_port) + logger.info("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore + self.ready_event.set() + decoder = msgspec.msgpack.Decoder(type=tuple) + while True: + try: + frames = sock.recv_multipart() + if len(frames) < 2: + logger.error("Invalid message format: %s", frames) + continue + + identity = frames[0] + payload = [f for f in frames[1:] if f != b""] + if len(payload) != 1: + logger.error("Invalid message format: %s", frames) + continue + + msg = decoder.decode(payload[0]) + if msg[0] == GET_META_MSG: + sock.send_multipart((identity, b"", encoded_data)) + elif msg[0] == DONE_RECVING_MSG: + logger.debug("Got DONE_RECVING_MSG for request %s", + msg[1]) + request_id = msg[1] + self.task_tracker.update_done_task_count(request_id) + # Acknowledge the request completion. + while True: + try: + # Send ACK to the sender. + sock.send_multipart( + (identity, b"", b"ACK"), + flags=zmq.NOBLOCK) # type: ignore + break + except zmq.Again: # type: ignore + # If the socket is not ready, retry sending. + logger.debug( + "Socket not ready, retrying to send ACK for " + "request %s", msg[1]) + time.sleep(0.01) + else: + logger.error( + "Connection listener got unexpected message %s", + msg) + except Exception as e: + logger.error("Connection listener got exception %s: %s", + type(e), e) + + +class KVCacheRecvingThread(threading.Thread): + + def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine, + local_engine_id: str, local_handshake_port: int, + local_kv_caches_base_addr: list[int], block_len: list[int], + ready_event: threading.Event, vllm_config: VllmConfig, + kv_caches: dict[str, Any]): + super().__init__(daemon=True, name="KVCacheRecvingThread") + self.tp_rank = tp_rank + self.tp_size = tp_size + + self.local_engine_id = local_engine_id + self.local_handshake_port = local_handshake_port + self.engine = engine + self.ready_event = ready_event + + self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ + defaultdict(dict) + self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \ + local_kv_caches_base_addr + self.remote_te_port: dict[str, dict[int, int]] = \ + defaultdict(dict) + self.block_len = block_len + # TODO(jianzs): find a better way to detect MLA. 
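+        # block_len has one entry per cache component of a block: two for MLA
+        # (the "norm" and "pe" caches registered in register_kv_caches), three
+        # when the sparse layout adds a third cache, and a single entry for the
+        # standard KV layout.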
+ self.use_mla = len(block_len) == 2 + self.use_sparse = len(block_len) == 3 + + self.request_queue: queue.Queue[Any] = queue.Queue() + self.executor = ThreadPoolExecutor(max_workers=32) + + self.task_tracker = KVCacheTaskTracker() + + self.encoder = msgspec.msgpack.Encoder() + self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + self.remote_sockets_lock = threading.Lock() + self.remote_sockets: dict[ # type: ignore + str, deque[zmq.Socket]] = defaultdict( # type: ignore + deque) + self.remote_poller = zmq.Poller() # type: ignore + self.timeout = 1.0 # seconds + + self.vllm_config = vllm_config + self.model_config = self.vllm_config.model_config + self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads + self.kv_caches = kv_caches + + def add_request(self, request_id: str, local_block_ids: list[int], + remote_block_ids: list[int], remote_engine_id: str, + remote_host: str, remote_handshake_port: int, offset: int, + num_need_pulls: int): + """Add a new request to the queue for processing.""" + logger.debug(f"Adding request {request_id} to the queue.") + self.request_queue.put({ + "request_id": request_id, + "local_block_ids": local_block_ids, + "remote_block_ids": remote_block_ids, + "remote_engine_id": remote_engine_id, + "remote_host": remote_host, + "remote_handshake_port": remote_handshake_port, + "offset": offset, + "num_need_pulls": num_need_pulls + }) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + return self.task_tracker.get_and_clear_finished_requests() + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + request_id = req_meta["request_id"] + remote_host = req_meta["remote_host"] + remote_handshake_port = req_meta["remote_handshake_port"] + offset = req_meta["offset"] + num_need_pulls = req_meta["num_need_pulls"] + + try: + logger.debug( + f"Starting to transfer KV cache for request {request_id}.") + self._transfer_kv_cache(req_meta) + logger.debug( + f"Finished transferring KV cache for request {request_id}.") + except Exception as e: + logger.error("Failed to transfer KV cache for request " + f"{request_id}: {e}") + finally: + # Always send the done signal to the remote host to ensure proper + # resource cleanup. Failing to do so may cause a memory leak on the + # remote host. 
+ self._send_done_recv_signal(request_id, remote_host, + remote_handshake_port) + if offset == num_need_pulls - 1: + self.task_tracker.update_done_task_count(request_id) + self.request_queue.task_done() + + def _transfer_kv_cache(self, req_meta: dict[str, Any]): + """Handle a KV cache transfer request.""" + request_id = req_meta["request_id"] + remote_block_ids = req_meta["remote_block_ids"] + local_block_ids = req_meta["local_block_ids"] + remote_engine_id = req_meta["remote_engine_id"] + remote_host = req_meta["remote_host"] + remote_handshake_port = req_meta["remote_handshake_port"] + offset = req_meta["offset"] + self.num_need_pulls = req_meta["num_need_pulls"] + + # Full prefix cache hit: do not need to read remote blocks, just notify + # P worker that we have the blocks we need. + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + return + + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + # Check if we have the remote metadata cached. + if remote_engine_id not in self.kv_caches_base_addr or \ + remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]: + self._get_remote_metadata(remote_host, remote_handshake_port) + + if self.num_need_pulls == 1: + grouped_remote_block_ids, grouped_local_block_ids = \ + group_concurrent_contiguous(remote_block_ids, local_block_ids) + else: + remote_block_ids = list(map(lambda x: [x], remote_block_ids)) + local_block_ids = list(map(lambda x: [x], local_block_ids)) + grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids + num_transfer_groups = len(grouped_remote_block_ids) + + remote_kv_caches_base_addrs = \ + self.kv_caches_base_addr[remote_engine_id][remote_handshake_port] + local_kv_caches_base_addrs = \ + self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port] + remote_transfer_port = self.remote_te_port[remote_engine_id][ + remote_handshake_port] + num_blocks = len(local_block_ids) + session_id = f"{remote_host}:{remote_transfer_port}" + + req_start_time = time.perf_counter() + src_list, dst_list, length_list = [], [], [] + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)): + if self.use_mla: + block_len = (self.block_len[k % 2]) + elif self.use_sparse: + block_len = (self.block_len[k % 3]) + else: + block_len = (self.block_len[0]) + inner_block_len = block_len // self.num_need_pulls + for remote_block_id, local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids): + src = src_layer_base_addr + local_block_id[ + 0] * block_len + offset * inner_block_len + dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len + length = inner_block_len * len(local_block_id) + src_list.append(src) + dst_list.append(dst) + length_list.append(length) + + ret = self.engine.batch_transfer_sync_read(session_id, src_list, + dst_list, length_list) + if ret < 0: + logger.error("Mooncake transfer failed for request %s", + req_meta["request_id"]) + raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") + + req_end_time = time.perf_counter() + req_transfer_elapsed = (req_end_time - req_start_time) * 1000 + logger.info( + "KV cache transfer for request %s took %.2f ms (%d groups," + " %d blocks). 
local_ip %s local_device_id %s remote_session_id %s", + request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, + get_ip(), self.tp_rank, session_id) + if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1: + self._cat_kv_cache(grouped_local_block_ids) + + def _cat_kv_cache(self, block_ids: list[list[int]]): + # Get necessary parameters + k_cache = list(self.kv_caches.values())[0][0] + kv_shape = k_cache.shape + dtype = k_cache.dtype + device = k_cache.device + head_dim = self.model_config.hf_config.head_dim + block_size = self.vllm_config.cache_config.block_size + num_kv_head = max( + self.model_config.hf_config.num_key_value_heads // self.tp_size, 1) + + flat_block_ids = [item for sublist in block_ids for item in sublist] + block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32) + num_blocks = len(flat_block_ids) + block_len = num_blocks * block_size + + # Create device tensors for copy operations + block_table = block_ids_tensor.view(1, -1).to(device=device) + block_len_tensor = torch.tensor([block_len], + dtype=torch.int32).to(device=device) + seq_start_tensor = torch.tensor([0], + dtype=torch.int32).to(device=device) + + # Initialize buffers + k_buffer = torch.empty(block_len, + num_kv_head, + head_dim, + dtype=dtype, + device=device) + v_buffer = torch.empty(block_len, + num_kv_head, + head_dim, + dtype=dtype, + device=device) + + # Create slot mapping for reshape operations + block_offsets = torch.arange(0, block_size, dtype=torch.int32) + slot_mapping = (block_offsets.reshape( + (1, block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * block_size) + slot_mapping = slot_mapping.flatten().to(device=device) + + # Process each layer in the KV cache + for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): + if len( + k_cache_layer.shape + ) == 3: # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim] + k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1], + num_kv_head, head_dim) + v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1], + num_kv_head, head_dim) + # Load cache data into buffers + torch_npu.atb.npu_paged_cache_load( + k_cache_layer, + v_cache_layer, + block_table, + block_len_tensor, + seq_starts=seq_start_tensor, + key=k_buffer, + value=v_buffer, + ) + + # Transpose KV cache + k_buffer = self._transpose_kv_cache_between_head( + k_buffer, num_blocks, block_size, block_len, num_kv_head) + v_buffer = self._transpose_kv_cache_between_head( + v_buffer, num_blocks, block_size, block_len, num_kv_head) + + # Reshape and cache the processed buffers + torch_npu._npu_reshape_and_cache( + key=k_buffer, + value=v_buffer, + key_cache=k_cache_layer, + value_cache=v_cache_layer, + slot_indices=slot_mapping, + ) + + # Clean up buffers + del k_buffer, v_buffer + + def _transpose_kv_cache_between_head(self, buffer: torch.Tensor, + num_blocks: int, block_size: int, + block_len: int, + num_kv_head: int) -> torch.Tensor: + buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1) + buffer.transpose_(1, 2) + return buffer.contiguous().view(block_len, num_kv_head, -1) + + def _get_remote_metadata(self, remote_host: str, + remote_handshake_port: int) -> None: + """Get the metadata from the remote host.""" + sock: Optional[zmq.Socket] = None # type: ignore + try: + sock = self._get_remote_socket(remote_host, remote_handshake_port) + ensure_zmq_send(sock, self.encoder.encode((GET_META_MSG, ""))) + metadata_bytes = ensure_zmq_recv(sock, self.remote_poller) + agent_meta = 
self.decoder.decode(metadata_bytes) + engine_id = agent_meta.engine_id + assert engine_id != self.local_engine_id, ( + f"Conflict engine id {engine_id} with local engine id " + f"{self.local_engine_id}.") + self.kv_caches_base_addr[engine_id][remote_handshake_port] = \ + agent_meta.kv_caches_base_addr + self.remote_te_port[engine_id][remote_handshake_port] = \ + agent_meta.te_rpc_port + finally: + if sock is not None: + self._return_remote_socket(sock, remote_host, + remote_handshake_port) + logger.debug("Returned socket to pool for %s:%d", remote_host, + remote_handshake_port) + + def _send_done_recv_signal(self, request_id: str, remote_host: str, + remote_handshake_port: int): + logger.debug("Sending done recving signal for request %s to %s:%d", + request_id, remote_host, remote_handshake_port) + sock: Optional[zmq.Socket] = None # type: ignore + try: + sock = self._get_remote_socket(remote_host, remote_handshake_port) + data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id)) + ensure_zmq_send(sock, data_bytes) + resp = ensure_zmq_recv(sock, + self.remote_poller, + timeout=self.timeout) + logger.debug( + f"Received response for request {request_id}: {resp.decode('utf-8')}" + ) + if resp != b"ACK": + logger.error("Failed to receive ACK for request %s from %s:%d", + request_id, remote_host, remote_handshake_port) + raise RuntimeError( + f"Failed to receive ACK, resp: {resp.decode('utf-8')}") + finally: + if sock is not None: + self._return_remote_socket(sock, remote_host, + remote_handshake_port) + logger.debug("Returned socket to pool for %s:%d", remote_host, + remote_handshake_port) + + def _get_remote_socket( + self, remote_host: str, + remote_handshake_port: int) -> zmq.Socket: # type: ignore + """Get a socket to the remote host.""" + remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) + with self.remote_sockets_lock: + if self.remote_sockets[remote_path]: + return self.remote_sockets[remote_path].popleft() + + ctx = zmq.Context() # type: ignore + sock = make_zmq_socket( + ctx=ctx, + path=remote_path, + socket_type=zmq.REQ, # type: ignore + bind=False) + sock.setsockopt( + zmq.SNDTIMEO, # type: ignore + int(self.timeout * 1000)) + self.remote_poller.register(sock, zmq.POLLIN) # type: ignore + return sock + + def _return_remote_socket( + self, + sock: zmq.Socket, # type: ignore + remote_host: str, + remote_handshake_port: int) -> None: + """Return the remote socket to the pool.""" + remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) + with self.remote_sockets_lock: + self.remote_sockets[remote_path].append(sock) + + +class MooncakeConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + self.requests_to_send: dict[str, float] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + ) + + +class MooncakeConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \ + 
MooncakeConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[MooncakeConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MooncakeConnectorWorker( + vllm_config, str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeConnector does not do layerwise saving.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """MooncakeConnector does not save explicitly.""" + pass + + +class MooncakeConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + init_ascend_config(vllm_config) + self.ascend_config = get_ascend_config() + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + self.local_ip = get_ip() + logger.info("Initializing Mooncake Scheduler %s", engine_id) + + self.side_channel_host = get_ip() + self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ + vllm_config.parallel_config.data_parallel_size + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank * + vllm_config.parallel_config.tensor_parallel_size) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. 
Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[str, float] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + local_block_ids = (blocks.get_unhashed_block_ids() + if num_external_tokens > 0 else []) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, local_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MooncakeConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + meta.requests_to_send = self._reqs_need_send + self._reqs_need_send = {} + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. 
+ """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + computed_block_ids = block_ids + delay_free_blocks = len(computed_block_ids) > 0 + if delay_free_blocks: + logger.info("Delaying free of %d blocks for request %s", + len(computed_block_ids), request.request_id) + self._reqs_need_send[request.request_id] = time.time() + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + last_token_id=request.output_token_ids[-1], + ) + + +class MooncakeConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self._get_prefill_decode_size(vllm_config) + os.environ["ASCEND_TRANSFER_TIMEOUT"] = str( + get_transfer_timeout_value()) + if self._prefill_tp_size < self._decode_tp_size: + raise ValueError( + f"prefill_tp_size: {self._prefill_tp_size} must be greater than" + f" or equal to the decode_tp_size: {self._decode_tp_size}") + + # Metadata. + self.vllm_config = vllm_config + self.ascend_config = get_ascend_config() + self.engine_id = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_group = get_tp_group() + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.dp_size = vllm_config.parallel_config.data_parallel_size_local + self.kv_caches: dict[str, torch.Tensor] = {} + self.side_channel_host = get_ip() + self.max_device_id = self.tp_size * self.dp_size + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank * + vllm_config.parallel_config.tensor_parallel_size) + self.handshake_port = self.side_channel_port + self.tp_rank + self.sockets: dict = {} + + # get tp device id + # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940 + # introducing some changes + device_ids_str = envs_ascend.PHYSICAL_DEVICES + if device_ids_str is None: + device_ids = list( + range(self.dp_rank * self.tp_size, + (self.dp_rank + 1) * self.tp_size)) + else: + device_ids = list(map(int, device_ids_str.split(','))) + start_index = self.dp_rank * self.tp_size + end_index = start_index + self.tp_size + if len(device_ids) < end_index: + raise ValueError( + f"Not enough physical devices available for DP rank {self.dp_rank}. 
" + f"Expected at least {end_index} devices, but found {len(device_ids)} " + "in PHYSICAL_DEVICES.") + device_ids = device_ids[start_index:end_index] + assert len(device_ids) > self.tp_rank # type: ignore + self.device_id = device_ids[self.tp_rank] # type: ignore + + if vllm_config.kv_transfer_config.get_from_extra_config( + 'use_ascend_direct', True): + hostname = self.side_channel_host + else: + hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" + logger.info("Initializing Mooncake work %s", engine_id) + self.engine = get_global_te(hostname, device_name=None) + self.te_rpc_port = self.engine.get_rpc_port() + + # Background thread for sending or receiving KV caches. + self.kv_send_thread: Optional[KVCacheSendingThread] = None + self.kv_recv_thread: Optional[KVCacheRecvingThread] = None + + # kv_transfer variables + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + if self.vllm_config.model_config.is_deepseek_mla: + self.num_need_pulls = 1 + else: + num_d_block_heads = max(1, + self.num_key_value_heads // self.tp_size) + num_p_block_heads = max( + 1, self.num_key_value_heads // self._prefill_tp_size) + self.num_need_pulls = num_d_block_heads // num_p_block_heads + + def _get_prefill_decode_size(self, vllm_config: VllmConfig): + # get prefill tp and dp size from extra config + prefill_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "prefill", {}) + + assert "tp_size" in prefill_parallel_config.keys() + self._prefill_tp_size = prefill_parallel_config["tp_size"] + + assert "dp_size" in prefill_parallel_config.keys() + self._prefill_dp_size = prefill_parallel_config["dp_size"] + + # get decode tp and dp size from extra config + decode_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "decode", {}) + assert "tp_size" in decode_parallel_config.keys() + self._decode_tp_size = decode_parallel_config["tp_size"] + assert "dp_size" in decode_parallel_config.keys() + self._decode_dp_size = decode_parallel_config["dp_size"] + + def _initialize( + self, + hostname: str, + device_name: Optional[str], + ) -> None: + """Initialize the mooncake instance.""" + device_name = device_name if device_name is not None else "" + ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend", + device_name) + if ret_value != 0: + raise RuntimeError( + f"Mooncake initialization failed with ret_value: {ret_value}") + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data.""" + + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + + # TODO(tms): Find a more robust way to detect and handle MLA + self.use_mla = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) and len( + first_kv_cache_tuple) == 2 + self.use_sparse = len(first_kv_cache_tuple) == 3 + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + elif self.use_sparse: + self.num_blocks = 
first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe), + first_kv_cache[2].element_size() * math.prod(block_shape_k) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", + self.num_blocks, block_shape_norm, block_shape_pe, + block_shape_k) + else: + # eager:[num_block, block_size, num_head, hidden_dim] + # torchair:[num_block, block_size, num_head*hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = len( + first_kv_cache.shape + ) - 1 # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + logger.info( + "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", + self.use_mla, self.use_sparse, first_kv_cache.shape) + + self.kv_caches = kv_caches + kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 2] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + elif self.use_sparse: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 3] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + else: + cache_list = [ + cache_or_caches + ] if self.use_mla or self.use_sparse else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[0] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + + # After KV Caches registered, start the sending or receiving thread. + metadata = MooncakeAgentMetadata( + engine_id=self.engine_id, + te_rpc_port=self.te_rpc_port, + kv_caches_base_addr=kv_caches_base_addr, + num_blocks=self.num_blocks, + ) + + ready_event = threading.Event() + if self.kv_role == 'kv_producer': + self.kv_send_thread = KVCacheSendingThread( + self.tp_rank, self._decode_tp_size, self.engine_id, + self.side_channel_host, self.side_channel_port, metadata, + ready_event, self.kv_caches) + self.kv_send_thread.start() + else: + self.kv_recv_thread = KVCacheRecvingThread( + self.tp_rank, self.tp_size, self.engine, self.engine_id, + self.handshake_port, kv_caches_base_addr, self.block_len, + ready_event, self.vllm_config, self.kv_caches) + self.kv_recv_thread.start() + ready_event.wait() + + def _register(self, ptr, length): + logger.debug( + "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " + "block_lens=%s", ptr, length, self.num_blocks, self.block_len) + ret_value = self.engine.register_memory(ptr, length) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed.") + + def get_finished(self) -> tuple[set[str], set[str]]: + done_sending = ( + self.kv_send_thread. 
+ get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_producer' else set()) + done_recving = ( + self.kv_recv_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_consumer' else set()) + if self.tp_rank == 0: + logger.debug( + "Number of completed KV cache send requests: %d, receive " + "requests: %d", len(done_sending), len(done_recving)) + return done_sending, done_recving + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + """Start loading KV blocks from remote engine.""" + for req_id, meta in metadata.requests.items(): + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + + choosen_rank_list = self._get_remote_tp_rank(req_id) + remote_handshake_port_list = [ + x + meta.remote_port for x in choosen_rank_list + ] + for i in range(self.num_need_pulls): + assert self.kv_recv_thread is not None + self.kv_recv_thread.add_request( + request_id=req_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_engine_id=meta.remote_engine_id, + remote_host=meta.remote_host, + remote_handshake_port=remote_handshake_port_list[i], + offset=i, + num_need_pulls=self.num_need_pulls) + + if self.kv_send_thread is not None: + for req_id, delay_start_time in metadata.requests_to_send.items(): + if self.tp_rank in self._prefill_get_remote_tp_rank(req_id): + self.kv_send_thread.add_delayed_request( + req_id, delay_start_time) + else: + self.kv_send_thread.add_not_transfer_request(req_id) + + def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]: + return sum(self._get_remote_tp_ranks_for_req(req_id), []) + + def _get_remote_tp_rank(self, req_id: str) -> List[int]: + return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] + + def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]: + if self._prefill_tp_size == self._decode_tp_size: + result = list(map(lambda x: [x], range(self._prefill_tp_size))) + return result + + seed = string_to_int64_hash(req_id) + rand = random.Random(seed) + sampled_nums = [] + ori_data = np.arange(self._prefill_tp_size) + # random split prefill tp list + if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: + # use deepseek mla, num_key_value_heads == 128, but consider as 1 + if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: + num_kv_head = 1 + else: + num_kv_head = self.num_key_value_heads + num_groups = len(ori_data) // num_kv_head + ori_data = ori_data.reshape(-1, num_groups) + rand_group_index = rand.sample(range(num_groups), \ + max(self._decode_tp_size // num_kv_head, 1)) # random choose a group + + choosen_group = ori_data[:, [rand_group_index]] + flattened = choosen_group.reshape(-1).tolist() + sampled_nums = [ + flattened[i:i + self.num_need_pulls] + for i in range(0, len(flattened), self.num_need_pulls) + ] + + # non-random split + else: + group_size = self._prefill_tp_size // self._decode_tp_size + for i in range(self._decode_tp_size): + ori_data_slice = ori_data[i * group_size:(i + 1) * group_size] + sampled_nums.append(ori_data_slice.tolist()) + return sampled_nums + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, 
zmq.REQ, zmq.DEALER):  # type: ignore
+        raise ValueError(f"Unexpected socket type: {socket_type}")
+
+    ctx: Optional[zmq.Context] = None  # type: ignore
+    try:
+        ctx = zmq.Context()  # type: ignore
+        yield make_zmq_socket(ctx=ctx,
+                              path=addr,
+                              socket_type=socket_type,
+                              bind=socket_type == zmq.ROUTER)  # type: ignore
+    finally:
+        if ctx is not None:
+            ctx.destroy(linger=0)
+
+
+def group_concurrent_contiguous(
+    src: List[int], dst: List[int]
+) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
+    """Vectorised NumPy implementation."""
+    src_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64)
+    dst_indices: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64)
+
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1)
+                   | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
+
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
+
+    return src_groups, dst_groups
+
+
+def string_to_int64_hash(input_str):
+    """
+    Hash the string using SHA-256 and convert it into an int64 integer.
+    """
+    hashed_bytes = hashlib.sha256(input_str.encode("utf-8")).digest()
+    trunked_bytes = hashed_bytes[:8]
+    uint64_value = struct.unpack("<Q", trunked_bytes)[0]
+    return uint64_value
+
+
+def ensure_zmq_send(
+        socket: zmq.Socket,  # type: ignore
+        data: bytes,
+        max_retries: int = 3) -> None:
+    retries_left = max_retries
+    while True:
+        try:
+            socket.send(data)
+            return
+        except zmq.ZMQError as e:  # type: ignore
+            retries_left -= 1
+            if retries_left > 0:
+                logger.warning(
+                    f"Send failed: {e}, retrying... ({retries_left} "
+                    "attempts left)")
+                time.sleep(0.1)
+            else:
+                logger.error(f"Send failed after all retries: {e}")
+                raise RuntimeError(f"Failed to send data after {max_retries} "
+                                   f"retries: {e}")
+
+
+def ensure_zmq_recv(
+        socket: zmq.Socket,  # type: ignore
+        poller: zmq.Poller,  # type: ignore
+        timeout: float = 1.0,
+        max_retries: int = 3) -> bytes:
+    retries_left = max_retries
+    while True:
+        try:
+            if dict(poller.poll(int(timeout * 1000))):  # milliseconds
+                data = socket.recv()
+                return data
+            else:
+                raise zmq.ZMQError("Receive timeout")  # type: ignore
+        except zmq.ZMQError as e:  # type: ignore
+            retries_left -= 1
+            if retries_left > 0:
+                logger.warning(f"Receive failed: {e}, retrying... 
" + f"({retries_left} attempts left)") + time.sleep(0.1) + else: + logger.error(f"Receive failed after all retries: {e}") + raise RuntimeError( + f"Failed to receive data after {max_retries} " + f"retries: {e}") diff --git a/vllm_npu/distributed/mooncake_layerwise_connector.py b/vllm_npu/distributed/mooncake_layerwise_connector.py new file mode 100644 index 0000000..be36503 --- /dev/null +++ b/vllm_npu/distributed/mooncake_layerwise_connector.py @@ -0,0 +1,1153 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import copy +import hashlib +import math +import os +import queue +import struct +import threading +import time +from collections import defaultdict, deque +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple + +import httpx +import msgspec +import numpy as np +import numpy.typing as npt +import torch +import torch_npu +import zmq +from mooncake.engine import TransferEngine # type: ignore +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, + get_tp_group, get_world_group) +from vllm.utils import get_ip, logger, make_zmq_path, make_zmq_socket +from vllm.v1.core.sched.output import SchedulerOutput + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.distributed.utils import (align_memory, + get_transfer_timeout_value, + kv_alltoall_and_rearrange) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +GET_META_MSG = b"get_meta_msg" +DONE_SENDING_MSG = b"done_sending_msg" + + +class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): + te_rpc_port: int + kv_caches_base_addr: list[int] + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + token_ids: list[int] + # Not None if layer-wise is disabled + remote_block_ids: list[int] + remote_engine_id: Optional[str] + remote_host: Optional[str] + remote_port: Optional[int] + remote_te_rpc_port: Optional[int] + remote_kv_caches_base_addr: Optional[list[int]] + metaserver: Optional[str] + + +class KVCacheSendingLayerThread(threading.Thread): + + def __init__(self, + engine: TransferEngine, + total_layers: int, + ready_event: threading.Event, + tp_rank: int, + pd_head_ratio: int, + num_head_replica: int, + kv_cache_base_addr: list[int], + use_mla: bool, + block_len: list[int], + first_kv_cache: torch.Tensor, + callback_func: Callable[..., None] = lambda x: None): + super().__init__(daemon=True, name="KVCacheSendingLayerThread") + self.engine = engine + self.tp_rank = tp_rank + self.pd_head_ratio = pd_head_ratio + self.num_head_replica = num_head_replica + self.kv_caches_base_addr = kv_cache_base_addr + self.total_layers = total_layers + self.use_mla = use_mla + self.block_len = block_len + self.model_stream = torch_npu.npu.current_stream() + self.current_layer = -1 + + if self.pd_head_ratio > 1: + # regesit kv buffer for tp inequal + alignment = 2 * 1024 * 1024 + self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment, + dtype=first_kv_cache.dtype, + device=first_kv_cache.device) + self.k_buffer = align_memory( + self.k_buffer, 
alignment)[:first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1]) + self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment, + dtype=first_kv_cache.dtype, + device=first_kv_cache.device) + self.v_buffer = align_memory( + self.v_buffer, alignment)[:first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1]) + + for tensor in (self.k_buffer, self.v_buffer): + assert tensor.data_ptr( + ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + ret_value = self.engine.register_memory( + tensor.data_ptr(), tensor.numel()) + logger.info( + f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}" + ) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed. ") + + self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor, + torch.Tensor]]() + + self.ready_event = ready_event + self.callback_func = callback_func + + def run(self): + local_rank = get_world_group().local_rank + device = torch.device(f"npu:{local_rank}") + torch.npu.set_device(device) + self.ready_event.set() + while True: + req_id, req_meta, layer_index, key, value = self.send_queue.get() + self._handle_request(req_id, req_meta, layer_index, key, value) + + def _handle_request(self, req_id, req_meta, layer_index, key, value): + try: + logger.debug( + f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." + ) + self._transfer_kv_cache(req_id, req_meta, layer_index, key, value) + logger.debug( + f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." + ) + except Exception as e: + logger.error("Failed to transfer KV cache for request " + f"{req_id}: {e}") + + def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): + # send kv layer to remote + if len(req_meta.local_block_ids) == 0: + return + # not need to send kv cache + if self.tp_rank % self.num_head_replica != 0: + return + + remote_host = req_meta.remote_host + remote_block_ids = req_meta.remote_block_ids + remote_te_port = req_meta.remote_te_rpc_port + remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr + local_kv_base_addr = self.kv_caches_base_addr + local_block_ids = req_meta.local_block_ids + + if self.pd_head_ratio == 1: + layer_local_kv_base_addr = [ + local_kv_base_addr[i] + for i in [2 * layer_index, 2 * layer_index + 1] + ] + layer_remote_kv_base_addr = [ + remote_kv_base_addrs[i] + for i in [2 * layer_index, 2 * layer_index + 1] + ] + grouped_remote_block_ids, grouped_local_block_ids = \ + group_concurrent_contiguous(remote_block_ids, local_block_ids) + + session_id = f"{remote_host}:{remote_te_port}" + src_list, dst_list, length_list = [], [], [] + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): + block_len = self.block_len[ + k % 2] if self.use_mla else self.block_len[0] + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids): + src = src_layer_base_addr + group_local_block_id[ + 0] * block_len + dst = dst_layer_base_addr + group_remote_block_id[ + 0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) + if self.current_layer != layer_index: + self.current_layer = layer_index + self.model_stream.synchronize() + ret = self.engine.batch_transfer_sync_write( + session_id, src_list, dst_list, length_list) + if ret < 0: + logger.error("Mooncake transfer 
failed for request %s", req_id) + raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") + else: + key = key.view(-1, key.shape[-1]) + value = value.view(-1, key.shape[-1]) + self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] -> + self.v_buffer[:value.shape[0]].copy_(value) + + layer_local_kv_base_addr = [ + self.k_buffer.data_ptr(), + self.v_buffer.data_ptr() + ] + + layer_remote_kv_base_addr = [ + remote_kv_base_addrs[i] + for i in [2 * layer_index, 2 * layer_index + 1] + ] + + grouped_remote_block_ids, _ = group_concurrent_contiguous( + remote_block_ids) + + session_id = f"{remote_host}:{remote_te_port}" + src_list, dst_list, length_list = [], [], [] + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): + src_layer_addr = src_layer_base_addr + for group_remote_block_id in grouped_remote_block_ids: + block_len = self.block_len[0] + remote_block_len = self.block_len[0] * self.pd_head_ratio + src_list.append(src_layer_addr) + + if src_layer_addr + len( + group_remote_block_id + ) * block_len > src_layer_base_addr + key.numel( + ) * key.element_size(): + length = src_layer_base_addr + key.numel( + ) * key.element_size() - src_layer_addr + else: + length = len(group_remote_block_id) * block_len + length_list.append(length) + + dst_list.append(dst_layer_base_addr + + group_remote_block_id[0] * + remote_block_len + length * + ((self.tp_rank // self.num_head_replica) % + self.pd_head_ratio)) + src_layer_addr += length + self.model_stream.synchronize() + ret = self.engine.batch_transfer_sync_write( + session_id, src_list, dst_list, length_list) + if ret < 0: + logger.error("Mooncake transfer failed for request %s", req_id) + raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") + + if layer_index == (self.total_layers - 1): + self.callback_func(req_id, req_meta) + + +class KVCacheRecvingLayerThread(threading.Thread): + + def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int, + pd_head_ratio: int, local_engine_id: str, + metadata: MooncakeAgentMetadata, + ready_event: threading.Event): + super().__init__(daemon=True, name="KVCacheRecvingLayerThread") + self.tp_rank = tp_rank + self.tp_size = tp_size + self.pd_head_ratio = pd_head_ratio + self.local_engine_id = local_engine_id + self.side_channel_host = get_ip() + self.side_channel_port = side_channel_port + self.lock = threading.Lock() + self.done_requests = set[str]() + self.task_tracker = dict[str, int]() + self.ready_event = ready_event + self.metadata = metadata + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. 
+ """ + with self.lock: + finished_requests = self.done_requests + self.done_requests = set() + return finished_requests + + def update_task(self, req_id): + with self.lock: + self.task_tracker[req_id] += 1 + if self.task_tracker[req_id] == self.pd_head_ratio: + self.task_tracker.pop(req_id) + self.done_requests.add(req_id) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + handshake_port = self.side_channel_port + self.tp_rank + path = make_zmq_path("tcp", self.side_channel_host, handshake_port) + logger.info("Starting listening on path: %s", path) + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(self.metadata) + with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore + self.ready_event.set() + decoder = msgspec.msgpack.Decoder(type=tuple) + while True: + try: + frames = sock.recv_multipart() + if len(frames) < 2: + logger.error("Invalid message format: %s", frames) + continue + + identity = frames[0] + payload = [f for f in frames[1:] if f != b""] + if len(payload) != 1: + logger.error("Invalid message format: %s", frames) + continue + + msg = decoder.decode(payload[0]) + if msg[0] == GET_META_MSG: + logger.info("Got GET META INFO for request %s", msg[0]) + sock.send_multipart((identity, b"", encoded_data)) + elif msg[0] == DONE_SENDING_MSG: + logger.debug("Got DONE_RECVING_MSG for request %s", + msg[1]) + request_id = msg[1] + self.update_task(request_id) + sock.send_multipart((identity, b"", b"ACK")) + else: + logger.error( + "Connection listener got unexpected message %s", + msg) + except Exception as e: + logger.error("Failed to decode message: %s", e) + + +class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req(self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + token_ids: Optional[list[int]] = None): + self.requests[request_id] = ReqMeta( + token_ids=token_ids or [], + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params.get("remote_block_ids", []), + remote_engine_id=kv_transfer_params.get("remote_engine_id", None), + remote_host=kv_transfer_params.get("remote_host", None), + remote_port=kv_transfer_params.get("remote_port", None), + remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port", + None), + remote_kv_caches_base_addr=kv_transfer_params.get( + "remote_kv_caches_base_addr", None), + metaserver=kv_transfer_params.get("metaserver", None), + ) + + +class MooncakeLayerwiseConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + self._connector_metadata = MooncakeLayerwiseConnectorMetadata() + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[MooncakeLayerwiseConnectorScheduler] = \ + MooncakeLayerwiseConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[ + MooncakeLayerwiseConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MooncakeLayerwiseConnectorWorker( + vllm_config, str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert 
self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + MooncakeLayerwiseConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeLayerwiseConnector does not do layerwise saving.""" + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + MooncakeLayerwiseConnectorMetadata) + self.connector_worker.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeLayerwiseConnector does not save explicitly.""" + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + MooncakeLayerwiseConnectorMetadata) + self.connector_worker.save_kv_layer(layer_name, kv_layer, + attn_metadata, + self._connector_metadata) + + def wait_for_save(self): + """MooncakeLayerwiseConnector does not save explicitly.""" + pass + + +class MooncakeLayerwiseConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + logger.info("Initializing Mooncake Scheduler %s", engine_id) + + self.side_channel_host = get_ip() + self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ + vllm_config.parallel_config.data_parallel_size + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank * + vllm_config.parallel_config.tensor_parallel_size) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. 
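+        # Producer-side requests are tracked separately in
+        # _reqs_need_send_layerwise (declared just below) until all of their
+        # tokens have been scheduled; build_connector_meta() then emits a
+        # ReqMeta for them and drops the entry.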
+ self._reqs_need_recv: dict[str, tuple[Request, list[int], + list[int]]] = {} + self._reqs_need_send_layerwise: dict[str, tuple[ + int, list[int], + Request]] = {} # req_id, (len(prompt), local_block_ids, request) + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeLayerwiseConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + + params = request.kv_transfer_params + logger.debug( + "MooncakeLayerwiseConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + local_block_ids = (blocks.get_unhashed_block_ids() + if num_external_tokens > 0 else []) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, + [], #request._all_token_ids, + local_block_ids) + + params["do_remote_prefill"] = False + + # Layerwise prefiller add request need send + if params is not None and params.get("do_remote_decode"): + local_block_ids = (blocks.get_block_ids()[0]) + self._reqs_need_send_layerwise[request.request_id] = (len( + request.all_token_ids), local_block_ids, request) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MooncakeLayerwiseConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, token_ids, + block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. 
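+            # add_new_req() turns the request's kv_transfer_params (remote
+            # engine/host/port, remote block ids and the metaserver address)
+            # into the ReqMeta that the worker consumes in start_load_kv().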
+ meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=token_ids) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + cached_reqs = scheduler_output.scheduled_cached_reqs + new_reqs = scheduler_output.scheduled_new_reqs + for req_id, new_blocks in zip(cached_reqs.req_ids, + cached_reqs.new_block_ids): + if req_id in self._reqs_need_send_layerwise and new_blocks is not None: + total_tokens, block_ids, req = self._reqs_need_send_layerwise[ + req_id] + block_ids.extend(new_blocks[0]) + + computed_tokens = dict( + list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + + [(x.req_id, x.num_computed_tokens) for x in new_reqs]) + for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( + ): + if req_id in self._reqs_need_send_layerwise: + total_tokens, block_ids, req = self._reqs_need_send_layerwise[ + req_id] + current_tokens = computed_tokens.get(req_id, + 0) + scheduled_tokens + if current_tokens == total_tokens: + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=[]) + self._reqs_need_send_layerwise.pop(req_id) + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + # layer_wise push, not need delay_free_blocks + return False, None + + +class MooncakeLayerwiseConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self._get_prefill_decode_size(vllm_config) + os.environ["ASCEND_TRANSFER_TIMEOUT"] = str( + get_transfer_timeout_value()) + if self._prefill_tp_size < self._decode_tp_size: + raise ValueError( + f"prefill_tp_size: {self._prefill_tp_size} must be greater than" + f" or equal to the decode_tp_size: {self._decode_tp_size}") + + if TransferEngine is None: + raise RuntimeError("mooncake is not available") + logger.info("Initializing Mooncake work %s", engine_id) + self.engine = TransferEngine() + + # Metadata. 
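+        # Parallel ranks/sizes and the layer count cached here are used below
+        # to derive the side-channel handshake port and this rank's physical
+        # NPU device id.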
+ self.vllm_config = vllm_config + self.local_engine_id: str = " " + self.engine_id = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_group = get_tp_group() + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.dp_size = vllm_config.parallel_config.data_parallel_size_local + self.kv_caches: dict[str, torch.Tensor] = {} + self.side_channel_host = get_ip() + self.max_device_id = self.tp_size * self.dp_size + self.total_layers = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + + self.executor = ThreadPoolExecutor(32) + self.metaserver_client = httpx.Client( + limits=httpx.Limits(max_connections=100000), + timeout=None) if self.tp_rank == 0 else None + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank * + vllm_config.parallel_config.tensor_parallel_size) + self.handshake_port = self.side_channel_port + self.tp_rank + self.sockets: dict = {} + + # get tp device id + # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940 + # introducing some changes + device_ids_str = envs_ascend.PHYSICAL_DEVICES + if device_ids_str is None: + device_ids = list( + range(self.dp_rank * self.tp_size, + (self.dp_rank + 1) * self.tp_size)) + else: + device_ids = list(map(int, device_ids_str.split(','))) + start_index = self.dp_rank * self.tp_size + end_index = start_index + self.tp_size + if len(device_ids) < end_index: + raise ValueError( + f"Not enough physical devices available for DP rank {self.dp_rank}. " + f"Expected at least {end_index} devices, but found {len(device_ids)} " + "in PHYSICAL_DEVICES.") + device_ids = device_ids[start_index:end_index] + assert len(device_ids) > self.tp_rank # type: ignore + self.device_id = device_ids[self.tp_rank] # type: ignore + + if vllm_config.kv_transfer_config.get_from_extra_config( + 'use_ascend_direct', True): + hostname = self.side_channel_host + else: + hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" + self._initialize(hostname=hostname, device_name=None) + self.te_rpc_port = self.engine.get_rpc_port() + + # Background thread for sending or receiving KV caches. 
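+        # The sending thread is started in register_kv_caches() when this
+        # worker is a KV producer (prefill side), the receiving thread when it
+        # is a KV consumer (decode side).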
+ self.kv_recv_layer_thread: Optional[KVCacheRecvingLayerThread] = None + self.kv_send_layer_thread: Optional[KVCacheSendingLayerThread] = None + + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.kv_caches_base_addr: list[int] = [] + + self.pd_tp_ratio = get_ascend_config().pd_tp_ratio + self.pd_head_ratio = get_ascend_config().pd_head_ratio + self.num_head_replica = get_ascend_config().num_head_replica + + self.first_kv_cache = None + self.remote_poller = zmq.Poller() # type: ignore + self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + self.encoder = msgspec.msgpack.Encoder() + + self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ + defaultdict(dict) + self.remote_te_port: dict[str, dict[int, int]] = \ + defaultdict(dict) + self.remote_sockets_lock = threading.Lock() + self.remote_sockets: dict[ # type: ignore + str, deque[zmq.Socket]] = defaultdict( # type: ignore + deque) + self.remote_poller = zmq.Poller() # type: ignore + self.timeout = 1.0 # seconds + + def _get_prefill_decode_size(self, vllm_config: VllmConfig): + # get prefill tp and dp size from extra config + prefill_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "prefill", {}) + + assert "tp_size" in prefill_parallel_config.keys() + self._prefill_tp_size = prefill_parallel_config["tp_size"] + + assert "dp_size" in prefill_parallel_config.keys() + self._prefill_dp_size = prefill_parallel_config["dp_size"] + + # get decode tp and dp size from extra config + decode_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "decode", {}) + assert "tp_size" in decode_parallel_config.keys() + self._decode_tp_size = decode_parallel_config["tp_size"] + assert "dp_size" in decode_parallel_config.keys() + self._decode_dp_size = decode_parallel_config["dp_size"] + + def _initialize( + self, + hostname: str, + device_name: Optional[str], + ) -> None: + """Initialize the mooncake instance.""" + device_name = device_name if device_name is not None else "" + ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend", + device_name) + if ret_value != 0: + raise RuntimeError( + f"Mooncake initialization failed with ret_value: {ret_value}") + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data.""" + + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + self.first_kv_cache = first_kv_cache + + # TODO(tms): Find a more robust way to detect and handle MLA + self.use_mla = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + else: + # [num_block, block_size, num_head, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * 
math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + + logger.info("Registering KV_Caches. use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) + + self.kv_caches = kv_caches + kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 2] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + else: + cache_list = [cache_or_caches + ] if self.use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[0] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + self.kv_caches_base_addr = kv_caches_base_addr + + # After KV Caches registered, start the sending or receiving thread. + metadata = MooncakeAgentMetadata( + te_rpc_port=self.te_rpc_port, + kv_caches_base_addr=self.kv_caches_base_addr, + ) + if self.vllm_config.kv_transfer_config.is_kv_producer: + ready_event = threading.Event() + self.kv_send_layer_thread = KVCacheSendingLayerThread( + engine=self.engine, + total_layers=self.total_layers, + ready_event=ready_event, + tp_rank=self.tp_rank, + pd_head_ratio=self.pd_head_ratio, + num_head_replica=self.num_head_replica, + kv_cache_base_addr=self.kv_caches_base_addr, + use_mla=self.use_mla, + block_len=self.block_len, + first_kv_cache=first_kv_cache, + callback_func=self.send_done_send_signal) + self.kv_send_layer_thread.start() + ready_event.wait() + + if self.vllm_config.kv_transfer_config.is_kv_consumer: + ready_event = threading.Event() + self.kv_recv_layer_thread = KVCacheRecvingLayerThread( + self.tp_rank, self.side_channel_port, self.tp_size, + self.pd_head_ratio, self.engine_id, metadata, ready_event) + self.kv_recv_layer_thread.start() + ready_event.wait() + + def _register(self, ptr, length): + logger.info( + "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " + "block_lens=%s", ptr, length, self.num_blocks, self.block_len) + ret_value = self.engine.register_memory(ptr, length) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed.") + + def _access_metaserver(self, url, message): + success = False + retry = 0 + while retry < 3 and success is False: + retry += 1 + try: + self.metaserver_client.post(url, json=message) + success = True + except Exception as e: + logger.error( + f"Failed to connect to metaserver: {url}, retry {retry} time." + ) + if retry == 3: + raise e + + def get_finished(self) -> tuple[set[str], set[str]]: + done_recving = ( + self.kv_recv_layer_thread. 
+ get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.vllm_config.kv_transfer_config.is_kv_consumer else set()) + if len(done_recving) > 0: + logger.info( + "Number of completed KV cache recv requests: %d, receive " + "requests: %d", 0, len(done_recving)) + return set(), done_recving + + def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): + """Start loading KV blocks from remote engine.""" + self.current_layer = 0 + if self.vllm_config.kv_transfer_config.is_kv_consumer: + for req_id, meta in metadata.requests.items(): + if self.tp_rank % self.tp_size == 0: + logger.info( + f"Send request: {req_id} to proxy metaserver: {meta.metaserver}" + ) + # All parameters here should appear in the returned dict of + # request_finished in the scheduler side except "request_id". + kv_transfer_params = dict( + token_ids=meta.token_ids, + request_id=req_id, + do_remote_prefill=False, + do_remote_decode=True, + remote_block_ids=meta.local_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + ) + future = self.executor.submit( + self._access_metaserver, + url=meta.metaserver, + message=kv_transfer_params, + ) + + def handle_exception(future): + if future.exception(): + logger.error( + f"Access metaserver fail: {future.exception()}" + ) + + future.add_done_callback(handle_exception) + assert self.kv_recv_layer_thread is not None + with self.kv_recv_layer_thread.lock: + self.kv_recv_layer_thread.task_tracker[req_id] = 0 + + def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, + torch.Tensor], + attn_metadata: "AttentionMetadata", + connector_metadata: MooncakeLayerwiseConnectorMetadata, + **kwargs) -> None: + """MooncakeLayerwiseConnector does not save explicitly.""" + if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( + ): + # enable decode prefix cache + for request in connector_metadata.requests.values(): + assert len(request.local_block_ids) >= len( + request.remote_block_ids + ), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num." + request.local_block_ids = request.local_block_ids[ + -len(request.remote_block_ids):] + if self.pd_head_ratio != 1: + + def sort_kv_cache(input_kv: list[list[int]]): + return torch.cat([ + torch.chunk(tensor, self.pd_head_ratio, dim=0)[x] + for x in range(self.pd_head_ratio) + for tensor in input_kv + ]) + + total_block_ids = [ + request.local_block_ids + for request in connector_metadata.requests.values() + ] + keys = [ + kv_layer[0][block_ids].reshape( + -1, *kv_layer[0].shape[2:]).clone() + for block_ids in total_block_ids + ] + values = [ + kv_layer[1][block_ids].reshape( + -1, *kv_layer[1].shape[2:]).clone() + for block_ids in total_block_ids + ] + key_block_size = keys[0].size(0) // len(total_block_ids[0]) + value_block_size = values[0].size(0) // len(total_block_ids[0]) + keys = sort_kv_cache(keys) # [req1_key, req2_key] + values = sort_kv_cache(values) + (keys, + values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, + values) + key_start_id = 0 + value_start_id = 0 + else: + key = None + value = None + for req_id, req_meta in connector_metadata.requests.items(): + logger.debug( + f"Add request {req_id} to kv send layer thread. 
{req_meta=}" + ) + if self.pd_head_ratio != 1: + key_block_num = len( + req_meta.local_block_ids) * key_block_size + value_block_num = len( + req_meta.local_block_ids) * value_block_size + key = keys[key_start_id:key_start_id + key_block_num] + value = values[value_start_id:value_start_id + + value_block_num] + key_start_id += key_block_num + value_start_id += value_block_num + req_meta_update = self.update_decoder_info(req_id, req_meta) + assert self.kv_send_layer_thread is not None + self.kv_send_layer_thread.send_queue.put( + (req_id, req_meta_update, self.current_layer, key, value)) + self.current_layer += 1 + + def _get_remote_socket( + self, remote_host: str, + remote_handshake_port: int) -> zmq.Socket: # type: ignore + """Get a socket to the remote host.""" + remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) + with self.remote_sockets_lock: + if self.remote_sockets[remote_path]: + return self.remote_sockets[remote_path].popleft() + + ctx = zmq.Context() # type: ignore + sock = make_zmq_socket( + ctx=ctx, + path=remote_path, + socket_type=zmq.REQ, # type: ignore + bind=False) + sock.setsockopt( + zmq.SNDTIMEO, # type: ignore + int(self.timeout * 1000)) + self.remote_poller.register(sock, zmq.POLLIN) # type: ignore + return sock + + def update_decoder_info(self, req_id, req_meta): + req_meta_update = copy.deepcopy(req_meta) + if self.pd_tp_ratio > 1: + req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank // self.pd_tp_ratio + else: + req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank + if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \ + req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]: + try: + encoded_data = self.encoder.encode((GET_META_MSG, req_id)) + sock = self._get_remote_socket(req_meta_update.remote_host, + req_meta_update.remote_port) + ensure_zmq_send(sock, encoded_data) + metadata_bytes = ensure_zmq_recv(sock, self.remote_poller) + agent_meta = self.decoder.decode(metadata_bytes) + except Exception as e: + logger.error( + f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}" + ) + assert req_meta_update.remote_engine_id != self.engine_id, ( + f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id " + f"{self.local_engine_id}.") + self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][ + req_meta_update.remote_port] = agent_meta.kv_caches_base_addr + self.remote_te_port[req_meta_update.remote_engine_id][ + req_meta_update.remote_port] = agent_meta.te_rpc_port + logger.info( + f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" + ) + req_meta_update.remote_te_rpc_port = self.remote_te_port[ + req_meta_update.remote_engine_id][req_meta_update.remote_port] + req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[ + req_meta_update.remote_engine_id][req_meta_update.remote_port] + return req_meta_update + + def send_done_send_signal(self, req_id, req_meta): + logger.info("Sending done sending signal for request %s to %s:%d", + req_id, req_meta.remote_host, req_meta.remote_port) + try: + path = make_zmq_path("tcp", req_meta.remote_host, + req_meta.remote_port) + msg_encoder = msgspec.msgpack.Encoder() + encoded_data = msg_encoder.encode((DONE_SENDING_MSG, req_id)) + 
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore + ensure_zmq_send(sock, encoded_data) + ack = sock.recv() + if ack != b"ACK": + raise ValueError(f"Unexpected ACK response: {ack}") + except Exception as e: + logger.error( + f"Sending done sending signal for request {req_id} to {req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}" + ) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): # type: ignore + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None # type: ignore + try: + ctx = zmq.Context() # type: ignore + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) # type: ignore + finally: + if ctx is not None: + ctx.destroy(linger=0) + + +def group_concurrent_contiguous( + src: List[int], + dst: List[int] = [] +) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + """Vectorised NumPy implementation.""" + if not dst: + src_only_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64) + + if src_only_indices.size == 0: + return [], [] + + brk = np.where((np.diff(src_only_indices) != 1))[0] + 1 + src_groups = np.split(src_only_indices, brk) + src_groups = [g.tolist() for g in src_groups] + + return src_groups, [] + + else: + src_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64) + dst_indices: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64) + + if src_indices.size == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) + | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] + + return src_groups, dst_groups + + +def string_to_int64_hash(input_str): + """ + Hash the string using SHA-256 and convert it into an int64 integer. + """ + hashed_bytes = hashlib.sha256(input_str.encode("utf-8")).digest() + trunked_bytes = hashed_bytes[:8] + uint64_value = struct.unpack(" 0: + logger.warning( + f"Send failed: {e}, retrying... ({retries_left} " + "attempts left)") + time.sleep(0.1) + else: + logger.error(f"Send failed after all retries: {e}") + raise RuntimeError(f"Failed to send data after {max_retries} " + f"retries: {e}") + + +def ensure_zmq_recv( + socket: zmq.Socket, # type: ignore + poller: zmq.Poller, # type: ignore + timeout: float = 1.0, + max_retries: int = 3) -> bytes: + retries_left = max_retries + while True: + try: + if dict(poller.poll(int(timeout * 1000))): # milliseconds + data = socket.recv() + return data + else: + raise zmq.ZMQError("Receive timeout") # type: ignore + except zmq.ZMQError as e: # type: ignore + retries_left -= 1 + if retries_left > 0: + logger.warning(f"Receive failed: {e}, retrying... 
" + f"({retries_left} attempts left)") + time.sleep(0.1) + else: + logger.error(f"Receive failed after all retries: {e}") + raise RuntimeError( + f"Failed to receive data after {max_retries} " + f"retries: {e}") diff --git a/vllm_npu/distributed/parallel_state.py b/vllm_npu/distributed/parallel_state.py new file mode 100644 index 0000000..a870059 --- /dev/null +++ b/vllm_npu/distributed/parallel_state.py @@ -0,0 +1,196 @@ +from typing import Optional + +import torch +from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, + init_model_parallel_group) + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config + +# Currently, mc2 op need their own group coordinator. +_MC2: Optional[GroupCoordinator] = None +_MLP_TP: Optional[GroupCoordinator] = None +_OTP: Optional[GroupCoordinator] = None +_LMTP: Optional[GroupCoordinator] = None +_P_TP: Optional[GroupCoordinator] = None + + +def get_mc2_group() -> GroupCoordinator: + assert _MC2 is not None, ("mc2 group is not initialized") + return _MC2 + + +def get_otp_group() -> GroupCoordinator: + assert _OTP is not None, ( + "output tensor parallel group is not initialized") + return _OTP + + +def get_lmhead_tp_group() -> GroupCoordinator: + assert _LMTP is not None, ( + "lm head tensor parallel group is not initialized") + return _LMTP + + +def get_mlp_tp_group() -> GroupCoordinator: + assert _MLP_TP is not None, ("mlp group is not initialized") + return _MLP_TP + + +def get_p_tp_group() -> GroupCoordinator: + assert _P_TP is not None, ( + "distributed prefill tensor parallel group is not initialized") + return _P_TP + + +def model_parallel_initialized(): + return (_MC2 is not None) + + +def init_ascend_model_parallel(parallel_config: ParallelConfig, ): + if model_parallel_initialized(): + return + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + backend = torch.distributed.get_backend(get_world_group().device_group) + + # The layout of all ranks: ExternalDP * EP + # ExternalDP is the data parallel group that is not part of the model, + # every dp rank can generate independently (in verl integration). 
+ all_ranks = torch.arange(world_size).reshape( + -1, parallel_config.data_parallel_size * + parallel_config.tensor_parallel_size) + + pd_tp_ratio = get_ascend_config().pd_tp_ratio + pd_head_ratio = get_ascend_config().pd_head_ratio + global _P_TP + assert _P_TP is None, ( + "distributed prefill tensor parallel group is already initialized") + prefill_tensor_model_parallel_size = pd_tp_ratio + # divide alltoall groups + if pd_head_ratio > 1 and get_current_vllm_config( + ).kv_transfer_config.is_kv_producer: + num_head_replica = get_ascend_config().num_head_replica + remote_tp_size = parallel_config.tensor_parallel_size // pd_tp_ratio + if num_head_replica <= 1: + group_ranks = all_ranks.view( + -1, prefill_tensor_model_parallel_size).unbind(0) + else: + group_ranks = all_ranks.clone().view( + parallel_config.data_parallel_size, -1, + num_head_replica) # [DP_size, num_head, num_head_replica] + group_ranks = group_ranks.permute(0, 2, 1) + group_ranks = group_ranks.reshape( + -1, + group_ranks.size(-1)) # [DP_size * num_head_replica, num_head] + alltoall_group_size = group_ranks.size(-1) // remote_tp_size + group_ranks = group_ranks.unsqueeze(-1).view( + parallel_config.data_parallel_size, num_head_replica, -1, + alltoall_group_size + ) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size] + group_ranks = group_ranks.reshape(-1, + alltoall_group_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + local_rank = get_world_group().local_rank + num = next( + (i for i, ranks in enumerate(group_ranks) if local_rank in ranks), + None) + _P_TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name=f"p_tp_{num}") + + global _MC2 + group_ranks = all_ranks.unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + _MC2 = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="mc2") + if envs_ascend.vllm_npu_ENABLE_MLP_OPTIMIZE: + global _MLP_TP + assert _MLP_TP is None, ( + "mlp tensor model parallel group is already initialized") + + mlp_tp = parallel_config.data_parallel_size + + all_ranks_mlp_head = torch.arange(world_size).reshape( + -1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa + group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + # message queue broadcaster is only used in tensor model parallel group + _MLP_TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="mlp_tp") + + # If oproj tensor parallel size is set, we will create a group for it. 
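+    # For example, world_size = 8 with oproj_tensor_parallel_size = 2 yields
+    # the rank groups [0, 1], [2, 3], [4, 5], [6, 7]; the lm-head TP groups
+    # further below are built with the same contiguous slicing.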
+ otp_size = get_ascend_config().oproj_tensor_parallel_size + if otp_size is not None: + group_ranks = [] + global _OTP + num_oproj_tensor_parallel_groups: int = (world_size // otp_size) + for i in range(num_oproj_tensor_parallel_groups): + ranks = list(range(i * otp_size, (i + 1) * otp_size)) + group_ranks.append(ranks) + _OTP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="otp") + + lmhead_tensor_parallel_size = get_ascend_config( + ).lmhead_tensor_parallel_size + if lmhead_tensor_parallel_size is not None: + group_ranks = [] + global _LMTP + num_lmhead_tensor_parallel_groups: int = (world_size // + lmhead_tensor_parallel_size) + for i in range(num_lmhead_tensor_parallel_groups): + ranks = list( + range(i * lmhead_tensor_parallel_size, + (i + 1) * lmhead_tensor_parallel_size)) + group_ranks.append(ranks) + _LMTP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="lmheadtp") + + +def get_mlp_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_mlp_tp_group().world_size + + +def get_mlp_tensor_model_parallel_rank(): + """Return world size for the tensor model parallel group.""" + return get_mlp_tp_group().rank_in_group + + +def destroy_ascend_model_parallel(): + global _MC2 + if _MC2: + _MC2.destroy() + _MC2 = None + + global _MLP_TP + if _MLP_TP: + _MLP_TP.destroy() + _MLP_TP = None + + global _LMTP + if _LMTP: + _LMTP.destroy() + _LMTP = None + + global _OTP + if _OTP: + _OTP.destroy() + _OTP = None + + global _P_TP + if _P_TP: + _P_TP.destroy() + _P_TP = None diff --git a/vllm_npu/distributed/utils.py b/vllm_npu/distributed/utils.py new file mode 100644 index 0000000..a633a35 --- /dev/null +++ b/vllm_npu/distributed/utils.py @@ -0,0 +1,61 @@ +import os + +import torch +import torch.distributed as dist + +from vllm_npu.distributed.parallel_state import get_p_tp_group + + +def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, + value: torch.TensorType): + if pd_tp_ratio <= 1: + return None, None + elif key is None or value is None: + raise ValueError("key or value is None") + k_output = alltoall_and_rearrange(pd_tp_ratio, key) + v_output = alltoall_and_rearrange(pd_tp_ratio, value) + return k_output, v_output + + +def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor): + num_kv_heads = input_tensor.size(1) + output_tensor = torch.zeros_like(input_tensor) + dist.all_to_all_single(output_tensor, + input_tensor, + group=get_p_tp_group().device_group) + input_tensor = 0 + result = rearrange_output(output_tensor, tp_ratio, num_kv_heads) + output_tensor = 0 + return result + + +def rearrange_output(base_output: torch.Tensor, cut_num: int, + num_kv_heads: int): + size_0 = base_output.size(0) + if size_0 % cut_num != 0: + raise ValueError( + f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]" + ) + chunk_size = size_0 // cut_num + reshaped = base_output.view(cut_num, chunk_size, -1) + transposed = reshaped.transpose(0, 1) + return transposed.contiguous().view(size_0, num_kv_heads, -1) + + +def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: + data_ptr = tensor.data_ptr() + aligned_addr = (data_ptr + alignment - 1) // alignment * alignment + offset = (aligned_addr - data_ptr) // tensor.element_size() + return tensor[int(offset):] + + +def get_transfer_timeout_value(): + ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "") + if len(ascend_transfer_timeout) > 0: 
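+        # An explicit ASCEND_TRANSFER_TIMEOUT always takes precedence;
+        # otherwise the value is estimated from HCCL_RDMA_TIMEOUT and
+        # HCCL_RDMA_RETRY_CNT below (the defaults 20 and 7 evaluate to
+        # int(4.096 * 2**20 * 7 // 1000 + 3000) = 33064).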
+ return int(ascend_transfer_timeout) + hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT', + '20')) # type: ignore + hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT', + '7')) # type: ignore + return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + + 3000) diff --git a/vllm_npu/envs.py b/vllm_npu/envs.py new file mode 100644 index 0000000..10a3846 --- /dev/null +++ b/vllm_npu/envs.py @@ -0,0 +1,183 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from typing import Any, Callable, Dict + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + +env_variables: Dict[str, Callable[[], Any]] = { + # max compile thread number for package building. Usually, it is set to + # the number of CPU cores. If not set, the default value is None, which + # means all number of CPU cores will be used. + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + # The build type of the package. It can be one of the following values: + # Release, Debug, RelWithDebugInfo. If not set, the default value is Release. + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + # Whether to compile custom kernels. If not set, the default value is True. + # If set to False, the custom kernels will not be compiled. Please note that + # the sleep mode feature will be disabled as well if custom kernels are not + # compiled. + "COMPILE_CUSTOM_KERNELS": + lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))), + # The CXX compiler used for compiling the package. If not set, the default + # value is None, which means the system default CXX compiler will be used. + "CXX_COMPILER": + lambda: os.getenv("CXX_COMPILER", None), + # The C compiler used for compiling the package. If not set, the default + # value is None, which means the system default C compiler will be used. + "C_COMPILER": + lambda: os.getenv("C_COMPILER", None), + # The version of the Ascend chip. If not set, the default value is + # ASCEND910B1(Available for A2 and A3 series). It's used for package building. + # Please make sure that the version is correct. + "SOC_VERSION": + lambda: os.getenv("SOC_VERSION", "ASCEND910B1"), + # If set, vllm-ascend will print verbose logs during compilation + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + # The home path for CANN toolkit. If not set, the default value is + # /usr/local/Ascend/ascend-toolkit/latest + "ASCEND_HOME_PATH": + lambda: os.getenv("ASCEND_HOME_PATH", None), + # The path for HCCL library, it's used by pyhccl communicator backend. If + # not set, the default value is libhccl.so。 + "HCCL_SO_PATH": + lambda: os.environ.get("HCCL_SO_PATH", None), + # The version of vllm is installed. 
This value is used for developers who + # installed vllm from source locally. In this case, the version of vllm is + # usually changed. For example, if the version of vllm is "0.9.0", but when + # it's installed from source, the version of vllm is usually set to "0.9.1". + # In this case, developers need to set this value to "0.9.0" to make sure + # that the correct package is installed. + "VLLM_VERSION": + lambda: os.getenv("VLLM_VERSION", None), + # Whether to enable the trace recompiles from pytorch. + "vllm_npu_TRACE_RECOMPILES": + lambda: bool(int(os.getenv("vllm_npu_TRACE_RECOMPILES", '0'))), + # Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and + # GroupedMatmulFinalizeRouting operators are combined to implement EP. + "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": + lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0')) + ), + # Whether to enable DBO feature for deepseek model. + "vllm_npu_ENABLE_DBO": + lambda: bool(int(os.getenv("vllm_npu_ENABLE_DBO", '0'))), + # Whether to enable the model execute time observe profile. Disable it when + # running vllm ascend in production environment. + "vllm_npu_MODEL_EXECUTE_TIME_OBSERVE": + lambda: bool(int(os.getenv("vllm_npu_MODEL_EXECUTE_TIME_OBSERVE", '0')) + ), + # Some models are optimized by vllm ascend. While in some case, e.g. rlhf + # training, the optimized model may not be suitable. In this case, set this + # value to False to disable the optimized model. + "USE_OPTIMIZED_MODEL": + lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))), + # The tolerance of the kv cache size, if the difference between the + # actual kv cache size and the cached kv cache size is less than this value, + # then the cached kv cache size will be used. + "vllm_npu_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE": + lambda: int( + os.getenv("vllm_npu_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)), + # Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue. + # We'll remove this flag in the future once it's stable enough. + "vllm_npu_ENABLE_TOPK_TOPP_OPTIMIZATION": + lambda: bool( + int(os.getenv("vllm_npu_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))), + # `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is + # used for llmdatadist to build the communication topology for kv cache transfer, it is + # a required variable if `LLMDataDistCMgrConnector` is used as kv connector for disaggregated + # pd. The rank table can be generated by adopting the script `gen_ranktable.sh` + # in vllm_npu's example folder. + "DISAGGREGATED_PREFILL_RANK_TABLE_PATH": + lambda: os.getenv("DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None), + # `LLMDataDistCMgrConnector` required variable. `vllm_npu_LLMDD_RPC_IP` is used as the + # rpc communication listening ip, which will be used to receive the agent metadata from the + # remote worker. + "vllm_npu_LLMDD_RPC_IP": + lambda: os.getenv("vllm_npu_LLMDD_RPC_IP", "0.0.0.0"), + # `LLMDataDistCMgrConnector` required variable. `vllm_npu_LLMDD_RPC_PORT` is used as the + # rpc communication listening port, which will be used to receive the agent metadata from the + # remote worker. + "vllm_npu_LLMDD_RPC_PORT": + lambda: int(os.getenv("vllm_npu_LLMDD_RPC_PORT", 5557)), + # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible + # and the mla_pa will be the default path of deepseek decode path. 
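+    # Unlike most switches above, this flag is read as a plain int (0/1); like
+    # every entry in this dict it is resolved lazily through the module-level
+    # __getattr__ defined at the end of this file.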
+ "vllm_npu_MLA_PA": + lambda: int(os.getenv("vllm_npu_MLA_PA", 0)), + # Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled. + # this feature is supported in A2, and eager mode will get better performance. + "vllm_npu_ENABLE_MATMUL_ALLREDUCE": + lambda: bool(int(os.getenv("vllm_npu_ENABLE_MATMUL_ALLREDUCE", '0'))), + # Whether to enable FlashComm optimization when tensor parallel is enabled. + # This feature will get better performance when concurrency is large. + "vllm_npu_ENABLE_FLASHCOMM1": + lambda: bool(int(os.getenv("vllm_npu_ENABLE_FLASHCOMM1", '0'))), + # Whether to enable MLP weight prefetch, only used in small concurrency. + "vllm_npu_ENABLE_PREFETCH_MLP": + lambda: bool(int(os.getenv("vllm_npu_ENABLE_PREFETCH_MLP", '0'))), + # buffer size for gate up prefetch + "vllm_npu_MLP_GATE_UP_PREFETCH_SIZE": + lambda: int( + os.getenv("vllm_npu_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), + # buffer size for down proj prefetch + "vllm_npu_MLP_DOWN_PREFETCH_SIZE": + lambda: int( + os.getenv("vllm_npu_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), + # Whether to enable dense model and general optimizations for better performance. + # Since we modified the base parent class `linear`, this optimization is also applicable to other model types. + # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models. + "vllm_npu_ENABLE_DENSE_OPTIMIZE": + lambda: bool(int(os.getenv("vllm_npu_ENABLE_DENSE_OPTIMIZE", '0'))), + # Whether to enable mlp optimize when tensor parallel is enabled. + # this feature in eager mode will get better performance. + "vllm_npu_ENABLE_MLP_OPTIMIZE": + lambda: bool(int(os.getenv("vllm_npu_ENABLE_MLP_OPTIMIZE", '0'))), + # Determine the number of physical devices in a non-full-use scenario + # caused by the initialization of the Mooncake connector. + "PHYSICAL_DEVICES": + lambda: os.getenv("PHYSICAL_DEVICES", None), + # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. + "MSMONITOR_USE_DAEMON": + lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), + "vllm_npu_ENABLE_MLAPO": + lambda: bool(int(os.getenv("vllm_npu_ENABLE_MLAPO", '0'))), + # Whether to enable transpose weight and cast format to FRACTAL_NZ. + "vllm_npu_ENABLE_NZ": + lambda: int(os.getenv("vllm_npu_ENABLE_NZ", 1)), +} + +# end-env-vars-definition + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in env_variables: + return env_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(env_variables.keys()) \ No newline at end of file diff --git a/vllm_npu/eplb/__init__.py b/vllm_npu/eplb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/eplb/adaptor/__init__.py b/vllm_npu/eplb/adaptor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/eplb/adaptor/abstract_adaptor.py b/vllm_npu/eplb/adaptor/abstract_adaptor.py new file mode 100644 index 0000000..ab37fde --- /dev/null +++ b/vllm_npu/eplb/adaptor/abstract_adaptor.py @@ -0,0 +1,44 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor. +from abc import abstractmethod +from typing import Any + + +class EplbAdaptor(): + + def __init__(self, **args): + pass + + @abstractmethod + def get_rank_expert_workload(self): + raise NotImplementedError + + @abstractmethod + def get_init_expert_map(self, num_moe_layers: Any) -> Any: + raise NotImplementedError + + @abstractmethod + def do_update_expert_map(self, layer_id: Any, + updated_expert_map: Any) -> Any: + raise NotImplementedError + + @abstractmethod + def do_update_expert_weight(self, layer_id: Any, + local_expert_to_replace: Any, + buffer_tensor_id: Any) -> Any: + raise NotImplementedError diff --git a/vllm_npu/eplb/adaptor/vllm_adaptor.py b/vllm_npu/eplb/adaptor/vllm_adaptor.py new file mode 100644 index 0000000..0481ce9 --- /dev/null +++ b/vllm_npu/eplb/adaptor/vllm_adaptor.py @@ -0,0 +1,289 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor. 
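+# VllmEplbAdaptor (defined below) is the bridge between the EPLB machinery and
+# a loaded vLLM MoE model: it exposes per-layer expert maps and expert
+# weights, and updates both in place when experts are remapped at runtime.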
+import json +from typing import Any + +import torch +import torch.distributed as dist +from vllm.logger import logger + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.eplb.adaptor.abstract_adaptor import EplbAdaptor + + +class VllmEplbAdaptor(EplbAdaptor): + + def __init__(self, model, **args): + super().__init__(**args) + self.model = model + self.rank_id = dist.get_rank() + self.world_size = dist.get_world_size() + self.param_dict = dict(self.model.named_parameters()) + if self.model.config.model_type == "qwen3_moe": + self.num_dense_layers = 0 + self.global_expert_num = self.model.config.num_experts + else: + self.num_dense_layers = self.model.config.first_k_dense_replace + self.global_expert_num = self.model.config.n_routed_experts + self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers + self.init_redundancy_expert = get_ascend_config( + ).init_redundancy_expert + + # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here + if self.model.quant_config is not None: + self.expert_weight_names = [ + "w13_weight", "w2_weight", "w13_weight_scale", + "w13_weight_offset", "w2_weight_scale", "w2_weight_offset" + ] + else: + self.expert_weight_names = ["w13_weight", "w2_weight"] + + self.expert_map_per_layer = dict( + ) # reference to expert map on device for expert map update + self.expert_map_per_layer_cpu = dict( + ) # copy of expert map on CPU to avoid device synchronize frequently + for layer_idx in range(self.num_moe_layers): + self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \ + self.model.get_expert_map(self.num_dense_layers + layer_idx) + + # TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved + num_buffer_tensor = torch.where( + self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel() + self.buffer_tensor_list: list[list[Any]] = [ + [] for _ in range(num_buffer_tensor) + ] + self.init_buffer_tensor(num_buffer_tensor) + + self.expert_param_per_layer = dict() + self.init_expert_param_per_layer() + + self.log2phy_map_per_layer = dict() + for layer_idx in range(self.num_moe_layers): + self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \ + self.model.get_log2phy_map(self.num_dense_layers + layer_idx) + + self.all_topk_ids = [] + + def init_buffer_tensor(self, num_buffer_tensor): + for buffer_id in range(num_buffer_tensor): + for name in self.expert_weight_names: + complete_name = "model.layers." + str( + self.num_dense_layers) + ".mlp.experts." + name + expert_tensor = self.param_dict[complete_name].data[0] + if name in ["w13_weight", "w2_weight"]: + expert_tensor = expert_tensor.clone() + buffer_tensor = torch.empty_like(expert_tensor) + self.buffer_tensor_list[buffer_id].append(buffer_tensor) + + def init_expert_param_per_layer(self): + num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \ + ".mlp.experts." + self.expert_weight_names[0]].data.shape[0] + for moe_layer_id in range(self.num_moe_layers): + layer_idx = self.num_dense_layers + moe_layer_id + self.expert_param_per_layer[layer_idx] = list() + for local_expert_id in range(num_local_expert): + self.expert_param_per_layer[layer_idx].append([ + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." 
+ + name].data[local_expert_id] + for name in self.expert_weight_names + ]) + + def get_rank_expert_workload(self) -> torch.Tensor: + self.moe_load = self.model.get_all_moe_loads() + return self.moe_load + + def get_init_expert_map(self, num_moe_layers): + expert_map = self.model.get_all_expert_map(num_moe_layers) + if dist.is_initialized(): + world_size = dist.get_world_size() + + gathered = torch.empty( + (world_size, *expert_map.shape), # [W, L, E] + dtype=expert_map.dtype, + device=expert_map.device) + + dist.all_gather_into_tensor(gathered, expert_map) + all_maps = gathered.permute(1, 0, 2) + all_expert_maps = all_maps.cpu() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \ + all_expert_maps[layer_idx][self.rank_id] + + return all_expert_maps + + def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path): + + try: + expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor( + expert_map_path) + expert_map_all = self.local2global(expert_map_tensor) + except (TypeError, FileNotFoundError, OSError): + expert_map_all = self.determine_expert_map_all() + + for layer_idx in range(num_moe_layers): + if self.model.config.model_type == "qwen3_moe": + self.expert_map_per_layer_cpu[layer_idx] = \ + expert_map_all[layer_idx][self.rank_id] + else: + self.expert_map_per_layer_cpu[layer_idx + self.num_dense_layers] = \ + expert_map_all[layer_idx][self.rank_id] + return expert_map_all + + def _expert_file_to_tensor(self, expert_map_path: str): + with open(expert_map_path, "r") as f: + data = json.load(f) + layers_num = data["moe_layer_count"] + gpus_num = data["layer_list"][0]["device_count"] + + tensor_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data.append(device["device_expert"]) + tensor_data.append(device_data) + expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32) + return expert_map_tensor, layers_num, gpus_num + logger.error(f"failed to read expert_map_path: {expert_map_path}") + + def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str): + if self.rank_id == 0: + num_local_experts = expert_maps.max() + 1 + expert_maps_local = self.global2local(expert_maps, + num_local_experts) + + expert_maps_list = expert_maps_local.tolist() + record: dict[str, Any] = { + "moe_layer_count": len(expert_maps_list), + "layer_list": [] + } + + for layer_idx, layer_data in enumerate(expert_maps_list): + layer_record: dict[str, Any] = { + "layer_id": layer_idx, + "device_count": len(layer_data), + "device_list": [] + } + + for device_idx, experts in enumerate(layer_data): + device_record = { + "device_id": device_idx, + "device_expert": experts + } + layer_record["device_list"].append(device_record) + + record["layer_list"].append(layer_record) + + with open(expert_map_record_path, "w") as f: + json.dump(record, f, indent=4) + + def do_update_expert_map(self, layer_id, updated_expert_map): + self.expert_map_per_layer[layer_id].copy_(updated_expert_map) + self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) + + def do_update_expert_weight(self, layer_id, local_expert_to_replace, + buffer_tensor_id): + for expert_tensor, buffer_tensor in zip( + self.expert_param_per_layer[layer_id][local_expert_to_replace], + self.buffer_tensor_list[buffer_tensor_id]): + expert_tensor.copy_(buffer_tensor) + logger.debug(f"Expert tensor shape is :{expert_tensor.shape}") + + def do_update_log2phy_map(self, layer_id, 
updated_log2phy_map): + if self.log2phy_map_per_layer[layer_id] is not None: + self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map) + + def global2local(self, placement: torch.Tensor, + E_local: int) -> torch.Tensor: + + L, G, _ = placement.shape + device = placement.device + + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return pt_local + + def local2global(self, placement_local: torch.Tensor) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def determine_expert_map_all(self): + if self.world_size == 1: + local_ids = torch.arange(self.global_expert_num, dtype=torch.int32) + return local_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1) + + local_num_experts = self.global_expert_num // self.world_size + + expert_map_all = torch.full( + (self.num_moe_layers, self.world_size, self.global_expert_num), + -1, + dtype=torch.int32) + + for r in range(self.world_size): + if r < self.world_size - 1: + start = r * local_num_experts + end = (r + 1) * local_num_experts + local_count = local_num_experts + else: + start = r * local_num_experts + end = self.global_expert_num + local_count = self.global_expert_num - r * local_num_experts + + if r < self.init_redundancy_expert: + local_count += 1 + if end < self.global_expert_num: + end += 1 + else: + start -= 1 + + local_ids = torch.arange(local_count, dtype=torch.int32) + expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand( + self.num_moe_layers, -1) + + return expert_map_all diff --git a/vllm_npu/eplb/core/__init__.py b/vllm_npu/eplb/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/eplb/core/eplb_device_transfer_loader.py b/vllm_npu/eplb/core/eplb_device_transfer_loader.py new file mode 100644 index 0000000..67e4d56 --- /dev/null +++ b/vllm_npu/eplb/core/eplb_device_transfer_loader.py @@ -0,0 +1,134 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
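determine_expert_map_all and the global2local/local2global helpers above share one convention: expert_map[layer, rank, global_expert_id] holds the local slot index of that expert on that rank, or -1 when the expert is not resident there. A small worked example of the default non-redundant split (the sizes are made up for illustration):

import torch

num_layers, world_size, global_expert_num = 1, 2, 8
local = global_expert_num // world_size   # 4 experts per rank

expert_map = torch.full((num_layers, world_size, global_expert_num),
                        -1, dtype=torch.int32)
for r in range(world_size):
    # rank r hosts global experts [r*local, (r+1)*local) in local slots 0..local-1
    expert_map[:, r, r * local:(r + 1) * local] = torch.arange(
        local, dtype=torch.int32)

print(expert_map[0])
# tensor([[ 0,  1,  2,  3, -1, -1, -1, -1],
#         [-1, -1, -1, -1,  0,  1,  2,  3]], dtype=torch.int32)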
+# +from enum import Enum + +import torch.distributed as dist +from vllm.logger import logger + + +class ExpertWeightUpdateState(Enum): + WAITING = 0 # waiting for updated expert_map by EplbWorker + READY = 1 # ready for d2d expert weights updating + TRANSFERRING = 2 # d2d finished and waiting for updating expert_map into model + + +class D2DExpertWeightLoader: + + def __init__(self): + self.comm_op_list = None + self.updated_expert_map = None + self.updated_log2phy_map = None + self.layer_id = -1 # layer id to be updated + self.state = ExpertWeightUpdateState.WAITING + self.recv_expert_list = [] + self.mock_flag = True + + def set_adator(self, eplb_adaptor): + self.eplb_adaptor = eplb_adaptor + + def generate_expert_d2d_transfer_task(self, expert_send_info, + expert_recv_info, updated_expert_map, + layer_id): + # When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task + if self.state != ExpertWeightUpdateState.WAITING: + logger.warning_once( + "current d2d weight update tasks are on-going, cannot accept new weight update task" + ) + return + + self.updated_expert_map = updated_expert_map + + self.layer_id = layer_id + self.comm_op_list = [] + for send_info in expert_send_info: + dst_rank, global_expert_id_to_send = send_info + local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[ + layer_id][global_expert_id_to_send].item() + for src_tensor in self.eplb_adaptor.expert_param_per_layer[ + layer_id][local_expert_id]: + src_tensor = src_tensor.clone() + self.comm_op_list.append( + dist.P2POp(dist.isend, src_tensor, dst_rank)) + + buffer_tensor_id = 0 + for recv_info in expert_recv_info: + recv_rank, global_expert_id_to_recv = recv_info + for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[ + buffer_tensor_id]: + self.comm_op_list.append( + dist.P2POp(dist.irecv, buffer_tensor, recv_rank)) + local_expert_to_replace = self.updated_expert_map[ + global_expert_id_to_recv].item() + self.recv_expert_list.append( + (local_expert_to_replace, buffer_tensor_id)) + buffer_tensor_id += 1 + + self.state = ExpertWeightUpdateState.READY + + def set_log2phy_map(self, log2phy_map): + self.updated_log2phy_map = log2phy_map + + def asyn_expert_weight_transfer(self, reqs): + # Only when send/recv tasks are parsed into self.comm_op_list, d2d send/recv tasks can be luanched + if self.state != ExpertWeightUpdateState.READY: + return + + # set asynchronous stream for d2d expert weight transfer + if self.comm_op_list: + ret_list = dist.batch_isend_irecv(self.comm_op_list) + reqs.extend(ret_list) + + self.state = ExpertWeightUpdateState.TRANSFERRING + + def update_expert_map_and_weight(self, reqs): + # Only after send/recv tasks have been luanched, expert_map and weight can be updated + if self.state != ExpertWeightUpdateState.TRANSFERRING: + return + + # Waiting for send/recv tasks finish + for req in reqs: + req.wait() + + if self.comm_op_list is not None: + self.comm_op_list = None + + # update expert_map + self.eplb_adaptor.do_update_expert_map(self.layer_id, + self.updated_expert_map) + + # update log2phy_map + self.eplb_adaptor.do_update_log2phy_map(self.layer_id, + self.updated_log2phy_map) + + # update expert weight + buffer_tensor_id = 0 + for recv_expert_info in self.recv_expert_list: + local_expert_to_replace, buffer_tensor_id = recv_expert_info + self.eplb_adaptor.do_update_expert_weight(self.layer_id, + local_expert_to_replace, + buffer_tensor_id) + + logger.info( + f"[EPLB] finished update expert weight for layer: {self.layer_id}") + + 
self.recv_expert_list = [] + self.updated_expert_map = None + self.layer_id = -1 + self.state = ExpertWeightUpdateState.WAITING + + def load_impl(self, old_expert_table, new_expert_table): + raise NotImplementedError diff --git a/vllm_npu/eplb/core/eplb_utils.py b/vllm_npu/eplb/core/eplb_utils.py new file mode 100644 index 0000000..d7fd17e --- /dev/null +++ b/vllm_npu/eplb/core/eplb_utils.py @@ -0,0 +1,189 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove eplb utils. +import os.path +import random +import sys + +import torch +from vllm.logger import logger + + +def determine_default_expert_map(global_expert_num, world_size, rank_id, + global_redundant_expert_num): + if world_size == 1: + local_ids = torch.arange(global_expert_num, dtype=torch.int32) + return (global_expert_num, local_ids) + + local_num_experts = global_expert_num // world_size + + expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32) + + if rank_id < world_size - 1: + start = rank_id * local_num_experts + end = (rank_id + 1) * local_num_experts + local_count = local_num_experts + else: + start = rank_id * local_num_experts + end = global_expert_num + local_count = global_expert_num - rank_id * local_num_experts + + if isinstance(local_count, int): + local_ids = torch.arange(local_count, dtype=torch.int32) + expert_map[start:end] = local_ids + + return (local_count, expert_map) + + +def generate_log2phy_map(expert_map): + num_local_experts = expert_map.max() + 1 + log2phy_map = expert_map.clone() + num_ranks, num_global_expert = log2phy_map.shape + + row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks, \ + num_global_expert) * num_local_experts + log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1] + + for idx in range(num_global_expert): + positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0] + negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0] + num_rank_holding_expert = positive_rank_idx.size(0) + + if num_rank_holding_expert == 0: + log2phy_map[:, idx] = torch.full((num_ranks, ), + 0, + dtype=log2phy_map.dtype) + + if num_rank_holding_expert == 1: + log2phy_map[negative_rank_idx, idx] = torch.full( + (num_ranks - 1, ), + log2phy_map[positive_rank_idx, idx].item(), + dtype=log2phy_map.dtype) + else: + try: + random_list = [ + random.choice(log2phy_map[positive_rank_idx, idx]) + for _ in range(num_ranks - num_rank_holding_expert) + ] + log2phy_map[negative_rank_idx, + idx] = torch.tensor(random_list, + dtype=log2phy_map.dtype) + except Exception as e: + logger.error(f"Fail to get log2phy_map: {str(e)}") + + return log2phy_map + + +def determine_default_log2phy_map(global_expert_num, world_size, rank_id): + if world_size == 1: + local_ids = torch.arange(global_expert_num, dtype=torch.int32) + expert_map_all = 
local_ids.unsqueeze(0).expand(world_size, -1) + log2phy_map_all = generate_log2phy_map(expert_map_all) + return log2phy_map_all[rank_id] + + local_num_experts = global_expert_num // world_size + + expert_map_all = torch.full((world_size, global_expert_num), + -1, + dtype=torch.int32) + + for r in range(world_size): + if r < world_size - 1: + start = r * local_num_experts + end = (r + 1) * local_num_experts + local_count = local_num_experts + else: + start = r * local_num_experts + end = global_expert_num + local_count = global_expert_num - r * local_num_experts + + if isinstance(local_count, int): + local_ids = torch.arange(local_count, dtype=torch.int32) + expert_map_all[r, start:end] = local_ids + + log2phy_map_all = generate_log2phy_map(expert_map_all) + + return log2phy_map_all[rank_id] + + +class EPLBParamUtils: + + @staticmethod + def check_iterations(iterations): + if not isinstance(iterations, int): + raise TypeError(f"The {iterations} is not int.") + if iterations <= 0: + raise ValueError( + f"The {iterations} can not less than or equal to 0.") + if iterations > sys.maxsize: + raise ValueError( + f"The {iterations} can not large than {sys.maxsize}") + + @staticmethod + def check_dynamic_eplb(dynamic_eplb): + if dynamic_eplb is None: + return + if not isinstance(dynamic_eplb, bool): + raise TypeError("The dynamic_eplb is not bool.") + if dynamic_eplb and os.getenv("DYNAMIC_EPLB", "false") != "true": + raise ValueError( + 'Can not enable dynamic_eplb when not export DYNAMIC_EPLB="true".' + ) + + @staticmethod + def check_expert_map_path(expert_map): + if expert_map is None: + return + if not isinstance(expert_map, str): + raise TypeError("The expert_map is not str.") + if not expert_map.strip(): + raise ValueError("The expert_map is not empty.") + _, ext = os.path.splitext(expert_map) + if ext.lower() != ".json": + raise TypeError("The expert_map is not json.") + if not os.path.exists(expert_map): + raise ValueError("The expert_map is not exist.") + try: + with open(expert_map, "w", encoding='utf-8') as f: + f.read() + except Exception as e: + raise IOError( + f"Fail read expert info from {expert_map}, please check the reading permission of {expert_map} : {e}" + ) + + @staticmethod + def check_expert_map_record_path(expert_map_record_path): + if expert_map_record_path is None: + return + if not isinstance(expert_map_record_path, str): + raise TypeError("The expert_map_record_path is not str.") + if not expert_map_record_path.strip(): + raise ValueError("The expert_map_record_path is empty.") + _, ext = os.path.splitext(expert_map_record_path) + if ext.lower() != ".json": + raise TypeError("The expert_map_record_path is not json.") + if os.getenv("EXPERT_MAP_RECORD", "false") != "true": + raise ValueError( + 'Can not enable expert_map_record_path when not export EXPERT_MAP_RECORD="true".' + ) + try: + with open(expert_map_record_path, "w", encoding='utf-8') as f: + f.write("") + except Exception as e: + raise IOError( + f"Fail write expert info to {expert_map_record_path}, please check the writing permission of {expert_map_record_path} : {e}" + ) diff --git a/vllm_npu/eplb/core/eplb_worker.py b/vllm_npu/eplb/core/eplb_worker.py new file mode 100644 index 0000000..6016b97 --- /dev/null +++ b/vllm_npu/eplb/core/eplb_worker.py @@ -0,0 +1,440 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
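generate_log2phy_map above turns a per-rank expert map into a logical-to-physical routing table: resident experts get the physical slot id rank_id * num_local_experts + local_slot, and ranks that do not host a logical expert are redirected to a rank that does (chosen at random when several qualify). A small worked example with two ranks and four logical experts (the tensor values are illustrative):

import torch

from vllm_npu.eplb.core.eplb_utils import generate_log2phy_map

# rank 0 hosts experts 0 and 1, rank 1 hosts experts 2 and 3
expert_map_all = torch.tensor([[0, 1, -1, -1],
                               [-1, -1, 0, 1]])

print(generate_log2phy_map(expert_map_all))
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]])
# i.e. every rank routes experts 0/1 to physical slots 0/1 (on rank 0)
# and experts 2/3 to physical slots 2/3 (on rank 1).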
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +from multiprocessing import Process, Queue +from typing import Any + +import networkx as nx # type: ignore +import numpy as np +import torch +import torch.distributed as dist +from vllm.logger import logger + +from vllm_npu.eplb.core.eplb_utils import generate_log2phy_map +from vllm_npu.eplb.core.policy.policy_factory import (DynamicConfig, + PolicyFactory) + + +class EplbWorker: + + def __init__(self, shared_dict, policy_type, enable_d2d: bool = True): + self.policy_type = policy_type + self.policy = PolicyFactory.generate_policy(policy_type, + DynamicConfig()) + self.shared_dict = shared_dict + self.old_expert_maps = None + self.enable_d2d = enable_d2d + self.rank_id = dist.get_rank() + + def do_update(self): + # put data in to queue + # in process self.policy.generate_policy() + # get epxert table && tensor + + # async stream + # D2D + # H2D + # Get initial expert_map + torch.set_num_threads(1) + if self.old_expert_maps is None: + self.old_expert_maps = self.get_init_expert_maps() + if self.old_expert_maps is not None: + self.num_local_experts = self.old_expert_maps.max() + 1 + else: + raise ValueError("Failed to get expert_maps from shared_dict.") + + # Get MOE load information + load_info = self.fetch_and_sum_load_info() + if load_info is None: + return + + # Get the updated expert table based on the workload information + old_placement = self.global2local(self.old_expert_maps, + self.num_local_experts) + _, _, new_placement = self.calculate_rebalance_experts( + load_info, old_placement) + + if not torch.is_tensor(new_placement): + new_placement = torch.tensor(new_placement) + self.check_expert_placement(old_placement, new_placement) + new_expert_maps = self.local2global(new_placement) + self.update_expert_map(new_expert_maps) + + if self.policy_type == 2: + update_info = self.compose_expert_update_info_bipartite( + new_expert_maps, self.old_expert_maps) + else: + update_info = self.compose_expert_update_info_greedy( + new_expert_maps, self.old_expert_maps) + self.old_expert_maps = new_expert_maps + logger.info("EPLB Process compute complete") + + packed_update_info = self.pack_update_info(update_info) + + return packed_update_info + + def check_expert_placement(self, old_placement, new_placement): + num_layers = old_placement.shape[0] + num_ranks = old_placement.shape[1] + + for layer_id in range(num_layers): + # check if any logical expert is not placed on any rank + if torch.unique(new_placement[layer_id]).numel() < torch.unique( + old_placement[layer_id]).numel(): + logger.error( + f"There exists expert not placed on any rank in layer {layer_id}" + ) + new_placement[layer_id] = old_placement[layer_id] + continue + + for rank_id in range(num_ranks): + new_placement_check = new_placement[layer_id][rank_id] + old_placement_check = old_placement[layer_id][rank_id] + + # check if same logical experts are placed on the same NPU + if new_placement_check.numel() != torch.unique( + new_placement_check).numel(): + logger.error( + f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank 
{rank_id} is invalid" + ) + new_placement[layer_id] = old_placement[layer_id] + break + + # check if there is any experts movement inside one NPU + expert_not_move = torch.isin(new_placement_check, + old_placement_check) + if not torch.equal(new_placement_check[expert_not_move], + old_placement_check[expert_not_move]): + logger.error( + f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid" + ) + new_placement[layer_id] = old_placement[layer_id] + break + + def compose_expert_update_info_bipartite(self, updated_expert_maps_org, + current_expert_maps_org): + # transform numpy array to torch tensor + updated_expert_maps = updated_expert_maps_org.clone() + current_expert_maps = current_expert_maps_org.clone() + updated_expert_maps = np.array(updated_expert_maps) + current_expert_maps = np.array(current_expert_maps) + + num_layers = current_expert_maps.shape[0] + + for layer_id in range(num_layers): + updated_expert_maps_this_layer = updated_expert_maps[layer_id] + current_expert_maps_this_layer = current_expert_maps[layer_id] + updated_expert_maps_this_layer_org = updated_expert_maps_org[ + layer_id] + + from typing import Any + + expert_send_info_this_layer: dict[Any, Any] = {} + expert_recv_info_this_layer: dict[Any, Any] = {} + + # Guard Clause: if there is no expert weight update, avoid subsequent processing + if (np.equal(updated_expert_maps_this_layer, + current_expert_maps_this_layer)).all(): + yield (expert_send_info_this_layer, + expert_recv_info_this_layer, + updated_expert_maps_this_layer_org, layer_id) + + # Parse expert_ids each rank needs to receive from other ranks + dst_rank_indices, experts_to_recv = np.where( + (current_expert_maps_this_layer == -1) + & (updated_expert_maps_this_layer != -1)) + + # record src ranks for potential transfer + src_ranks_set = dict() + for idx in range(len(dst_rank_indices)): + expert_id = experts_to_recv[idx].item() + if expert_id not in src_ranks_set: + src_ranks_set[expert_id] = np.where( + current_expert_maps_this_layer[:, expert_id] != -1)[0] + + # loop until all experts are scheduled + while len(dst_rank_indices) > 0: + # construct bipartite graph + graph_expert_update: nx.Graph = nx.Graph() + for idx in range(len(dst_rank_indices)): + dst_rank_id = dst_rank_indices[idx].item() + expert_id = experts_to_recv[idx].item() + # add src ranks + src_rank_ids = src_ranks_set[expert_id] + graph_expert_update.add_nodes_from(src_rank_ids, + bipartite=0) + # add dest rank + graph_expert_update.add_node(str(dst_rank_id), bipartite=1) + # add edges + for src_rank_id in src_rank_ids: + graph_expert_update.add_edge(src_rank_id, + str(dst_rank_id)) + + # graph may not be connected + connected_components = list( + nx.connected_components(graph_expert_update)) + all_matches = {} + # matching in this loop + for i, component in enumerate(connected_components): + subgraph = graph_expert_update.subgraph(component) + component_matching = nx.bipartite.maximum_matching( + subgraph) + all_matches.update(component_matching) + + for src_rank, dst_rank in all_matches.items(): + dst_rank = int(dst_rank) + assert src_rank != dst_rank + if graph_expert_update.nodes[src_rank]['bipartite'] == 0: + # currently not scheduled experts in rank dst_rank + experts_v = experts_to_recv[np.where( + dst_rank_indices == dst_rank)] + # src: src_rank, dest: dst_rank, expert: expert_id + expert_id = np.intersect1d( + experts_v, + np.where(current_expert_maps_this_layer[src_rank] + != -1))[0] + + # record send/rcv pairs + if src_rank not 
in expert_send_info_this_layer: + expert_send_info_this_layer[src_rank] = [] + if dst_rank not in expert_recv_info_this_layer: + expert_recv_info_this_layer[dst_rank] = [] + expert_send_info_this_layer[src_rank].append( + (dst_rank, expert_id)) + expert_recv_info_this_layer[dst_rank].append( + (src_rank, expert_id)) + + remove_index = np.where( + np.logical_and(dst_rank_indices == dst_rank, + experts_to_recv == expert_id)) + + # update + dst_rank_indices = np.delete(dst_rank_indices, + remove_index) + experts_to_recv = np.delete(experts_to_recv, + remove_index) + + yield (expert_send_info_this_layer, expert_recv_info_this_layer, + updated_expert_maps_this_layer_org, layer_id) + + # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases + def compose_expert_update_info_greedy(self, updated_expert_maps, + current_expert_maps): + num_layers = current_expert_maps.shape[0] + for layer_id in range(num_layers): + updated_expert_maps_this_layer = updated_expert_maps[layer_id] + current_expert_maps_this_layer = current_expert_maps[layer_id] + + expert_send_info_this_layer: dict[Any, Any] = {} + expert_recv_info_this_layer: dict[Any, Any] = {} + + # Guard Clause: if there is no expert weight update, avoid subsequent processing + if torch.equal(updated_expert_maps_this_layer, + current_expert_maps_this_layer): + yield (expert_send_info_this_layer, + expert_recv_info_this_layer, + updated_expert_maps_this_layer, layer_id) + + # Parse expert_ids each rank needs to receive from other ranks + dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \ + & (updated_expert_maps_this_layer != -1)) + + # Parse expert_ids each rank needs to send to other ranks + src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \ + & (updated_expert_maps_this_layer == -1)) + + for idx in range(len(dst_rank_indices)): + dst_rank_id = dst_rank_indices[idx].item() + expert_id = experts_to_recv[idx].item() + if dst_rank_id not in expert_recv_info_this_layer: + expert_recv_info_this_layer[dst_rank_id] = [] + + if not torch.isin(torch.tensor(expert_id), + experts_to_send).any(): + # if expert_id are not sent out from any npu, it will be copied from one npu holding this expert + candidate_src_rank_indices = torch.where( + current_expert_maps_this_layer[:, expert_id] != -1)[0] + else: + candidate_src_rank_indices = src_rank_indices[ + experts_to_send == expert_id] + + # TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node... + src_rank_id = candidate_src_rank_indices[0].item() + if src_rank_id not in expert_send_info_this_layer: + expert_send_info_this_layer[src_rank_id] = [] + + expert_send_info_this_layer[src_rank_id].append( + (dst_rank_id, expert_id)) + expert_recv_info_this_layer[dst_rank_id].append( + (src_rank_id, expert_id)) + + yield (expert_send_info_this_layer, expert_recv_info_this_layer, + updated_expert_maps_this_layer, layer_id) + + def calculate_rebalance_experts(self, load_info, old_placement): + """ + Compute `new_map` by calling the `rebalance_experts` method of the policy instance. + """ + if self.old_expert_maps is None: + return False, None, None + + changed, priority, new_map = self.policy.rebalance_experts( + old_placement, load_info) + return changed, priority, new_map + + def get_init_expert_maps(self): + """ + Read the initial expert_map from shared_dict. 
+ """ + return self.shared_dict.get("expert_maps", None) + + def fetch_and_sum_load_info(self): + """ + Each time the subprocess is awakened, read the latest moe_load + (shape: [num_moe_layers, num_experts_per_layer]) from shared_dict. + """ + return self.shared_dict.get("moe_load", None) + + def update_expert_map(self, expert_maps): + + self.shared_dict["expert_maps"] = expert_maps + + def global2local(self, placement: torch.Tensor, + E_local: int) -> tuple[torch.Tensor, torch.Tensor]: + + L, G, _ = placement.shape + device = placement.device + + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return pt_local + + def local2global(self, placement_local: torch.Tensor) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def pack_update_info(self, update_info_generator): + """ + Pack a list of update info tuples for efficient IPC. + """ + send_all = [] + recv_all = [] + maps = [] + log2phy_all = [] + layer_ids = [] + + for send_info, recv_info, new_expert_map, layer_id in update_info_generator: + + send_info_this_rank = send_info[ + self.rank_id] if self.rank_id in send_info else [] + recv_info_this_rank = recv_info[ + self.rank_id] if self.rank_id in recv_info else [] + send_all.append(send_info_this_rank) + recv_all.append(recv_info_this_rank) + + maps.append(new_expert_map[self.rank_id].numpy().tolist()) + + log2phy_map = generate_log2phy_map(new_expert_map) + log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist()) + + layer_ids.append(layer_id) + + return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids)) + + +class EplbProcess: + + def __init__(self, + shared_dict, + policy_type: int = 0, + enable_d2d: bool = True): + """ + Args: + shared_dict: Cross-process shared dict returned by Manager().dict() + policy_type: Integer passed to PolicyFactory.generate_policy + enable_d2d: Whether to enable D2D loading + """ + self.shared_dict = shared_dict + self.policy_type = policy_type + self.enable_d2d = enable_d2d + self.planner_q: Queue[Any] = Queue() + self.block_update_q: Queue[Any] = Queue(maxsize=1) + + # Create EplbWorker instance + self.worker = EplbWorker(self.shared_dict, self.policy_type, + self.enable_d2d) + + def worker_process(self, planner_q, block_update_q): + """ + Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete. 
+ """ + while True: + try: + planner_q.get() + + packed_update_info = self.worker.do_update() + + while True: + if not block_update_q.empty(): + continue + block_update_q.put(packed_update_info) + break + + except Exception as e: + logger.warning(f"[EPLB subprocess Exiting due to error: {e}", + exc_info=True) + break + + def _launch_process(self): + """ + Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). + """ + proc = Process(target=self.worker_process, + args=(self.planner_q, self.block_update_q), + daemon=True) + + proc.start() + return proc diff --git a/vllm_npu/eplb/core/policy/__init__.py b/vllm_npu/eplb/core/policy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/eplb/core/policy/policy_abstract.py b/vllm_npu/eplb/core/policy/policy_abstract.py new file mode 100644 index 0000000..8ef58e2 --- /dev/null +++ b/vllm_npu/eplb/core/policy/policy_abstract.py @@ -0,0 +1,42 @@ +# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. +from abc import abstractmethod + + +class DynamicConfig: + placement_policy = None + + max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host + ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed + num_die_per_host = 8 # Number of dies on each host machine + + +class EplbPolicy: + + def __init__(self, config: DynamicConfig): + self.config = config + + @abstractmethod + def rebalance_experts(self, current_expert_table, expert_workload): + """ + Pass in the weights and return expert replication and placement under relevant constraints. + INPUT: + current_expert_table: [layerId, rankId, expert_num_i] + expert_workload = expert_table[layer0][rankId][expert_num_i] + + RETURNED: (res, expert_table) + res: + 1 -- table_changed + 0 -- not_changed + + expert_table: [layerId, rankId, expert_num_i] + expert_num_i --- [0, MaxExpertPerRank] + expertID = expert_table[layer0][rankId][expert_num_i] + array_values: + [0, 1, 2, 3, 248] + [4, 5, 6, 7, 254] + [8, 9, 10, 11, 71] + ... + [252, 253, 254, 255, 0] + """ + pass diff --git a/vllm_npu/eplb/core/policy/policy_dynamic_ep.py b/vllm_npu/eplb/core/policy/policy_dynamic_ep.py new file mode 100644 index 0000000..5e77f4d --- /dev/null +++ b/vllm_npu/eplb/core/policy/policy_dynamic_ep.py @@ -0,0 +1,389 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. 
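policy_abstract.py above pins down the rebalancing contract, and EplbWorker.calculate_rebalance_experts unpacks the result as (changed, per-layer priority, new placement). A minimal sketch of a policy that satisfies that interface without moving any expert; NoOpPolicy is illustrative and not part of this patch, and it follows the same relative import style as the concrete policies in this directory:

import numpy as np

from .policy_abstract import EplbPolicy


class NoOpPolicy(EplbPolicy):
    """Baseline policy: keep the current placement unchanged."""

    def rebalance_experts(self, current_expert_table, expert_workload):
        placement = np.array(current_expert_table)
        per_layer_priority = np.arange(placement.shape[0])  # no layer is urgent
        # res = 0 signals "table not changed"; the placement is echoed back as-is
        return 0, per_layer_priority, placement.tolist()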
+from collections import defaultdict +from typing import cast + +import numpy as np + +from .policy_abstract import DynamicConfig, EplbPolicy + + +class DynamicTable: + # workload_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + workload_table = None + + # placement_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + placement_table = None + + +class DynamicEplb(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def add_redundant(current_expert_table, expert_workload, + num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict: dict[int, int] = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][ + expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + # Split hot (high-load) experts into redundant experts + def original_compute_balanced_pack_redundancy(origin_weights, card_num, + num_redundancy_expert): + # Step 1: Sort the items by weight in descending order (we are sorting by weight now) + # Sort based on the second element (the second value of each tuple) + route_expert_num = len(origin_weights) + route_expert_redundancy: list[list[int]] = [ + [] for _ in range(route_expert_num) + ] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], + kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * ( + len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / ( + len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + # Step 2: Calculate the number of items per box + expert_num = route_expert_num + num_redundancy_expert + items_per_box = expert_num // card_num # Number of items per box + remaining_items = expert_num % card_num # Number of items per box + + # Step 3: Initialize card_num boxes with empty lists to store item IDs + boxes: list[list[int]] = [[] for _ in range(card_num)] + boxes_weights: list[list[float]] = [[] for _ in range(card_num)] + box_weights = [0] * card_num # To store the total weight of each box + box_counts = [0] * card_num # To store the number of items in each box + index = 0 + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): 
+ cur_weight = 0 + for item, weight in origin_weights: + if item == i: + cur_weight = weight + + boxes[index].append(i) + boxes_weights[index].append(cur_weight) + box_weights[index] += cur_weight + box_counts[index] += 1 + index += 1 + + sorted_indices = np.argsort([t[1] for t in origin_weights], + kind='stable')[::-1] + origin_weights = [origin_weights[idx] for idx in sorted_indices] + # Step 4: Distribute items into boxes based on weight + for item_id, weight in origin_weights: + # Find the box with the least items but not full + min_box_index = -1 + for i in range(card_num): + if item_id in boxes[i]: + continue + # Only choose boxes that still have space (box_counts[i] < items_per_box) + if box_counts[i] < items_per_box or (box_counts[i] + == items_per_box + and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[ + min_box_index]: + min_box_index = i + + # Place the item (id) into the selected box + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + # If there's an imbalance in the remaining items, reduce the "remaining_items" counter + if box_counts[min_box_index] == (items_per_box + + 1) and remaining_items > 0: + remaining_items -= 1 + + # Step 5: Output each box's contents and total weight + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], # List of item IDs in the box + "weight": boxes_weights[i], + "total_weight": box_weights[i], # Total weight in this box + "item_count": box_counts[i] # Number of items in the box + }) + + return result, boxes + + # Split hot (high-load) experts into redundant experts + @staticmethod + def compute_balanced_pack_redundancy(origin_weights, card_num, + num_redundancy_expert): + route_expert_num = len(origin_weights) + route_expert_redundancy: list[list[int]] = [ + [] for _ in range(route_expert_num) + ] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], + kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * ( + len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / ( + len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + expert_num = route_expert_num + num_redundancy_expert + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes: list[list[int]] = [[] for _ in range(card_num)] + boxes_weights: list[list[float]] = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + all_weights = np.zeros((expert_num, ), dtype='object') + all_weights[:route_expert_num] = origin_weights + + index = route_expert_num + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + for item, weight in origin_weights: + if item == i: + all_weights[index] = (item, weight) + index += 1 + + sorted_indices = np.argsort([t[1] for t in all_weights], + kind='stable')[::-1] + all_weights = [all_weights[idx] for idx in sorted_indices] + for item_id, weight in all_weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] + == items_per_box + and remaining_items > 0): + 
if min_box_index == -1 or box_weights[i] < box_weights[ + min_box_index]: + if item_id not in boxes[i]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + # Scheme without redundant experts + @staticmethod + def compute_balanced_pack(origin_weights, card_num): + sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1] + weights = origin_weights[sorted_indices] + expert_num = len(weights) + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes: list[list[int]] = [[] for _ in range(card_num)] + boxes_weights: list[list[float]] = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + for item_id, weight in weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] + == items_per_box + and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[ + min_box_index]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu: int = np.sum(counts - 1) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer: list[float] = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + @staticmethod + def constraint_expert_local_exchange(current_expert_table, + global_deployment): + for layer_id in range(len(global_deployment)): + for card_id in range(len(global_deployment[layer_id])): + current_list = [ + int(x) for x in current_expert_table[layer_id][card_id] + ] + new_list = [ + int(x) for x in global_deployment[layer_id][card_id] + ] + num = len(new_list) + + new_index = [-1] * num + new_result = [-1] * num + remaining_elements = [] + + for i in range(num): + flag = True + for j in range(num): + if new_list[i] == current_list[j] and new_index[ + j] == -1: + new_index[j] = 0 + new_result[j] = current_list[j] + flag = False + break + if flag: + remaining_elements.append(new_list[i]) + + index = 0 + for k in range(num): + if new_result[k] == -1: + new_result[k] = remaining_elements[index] + index += 1 + + global_deployment[layer_id][card_id] = new_result + + return global_deployment + + def rebalance_experts(self, current_expert_table, expert_workload): + + info = DynamicTable() + info.workload_table = np.array(expert_workload) + info.placement_table = np.array(current_expert_table) + assert 
info.workload_table is not None + layer_num, num_npus, experts_per_npu = info.workload_table.shape + assert info.placement_table is not None + row = cast(np.ndarray, info.placement_table[0]) + expert_ids, counts = np.unique(row, return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + layer_workloads = self.add_redundant(info.placement_table, + info.workload_table, + num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer( + info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + # Perform load balancing and deploy redundant experts + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + # Validate that the number of experts, number of cards, and number of redundant experts do not exceed the number of cards + if num_original_expert != expert_num: + raise ValueError( + f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}" + ) + + if num_npus <= 0: + raise ValueError("the number of NPUs must be greater than 0") + + if num_npus < num_redundancy_expert: + raise ValueError( + f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}" + ) + + # Number of experts deployed on each card includes one redundant expert + global_deployment: list[list[list[int]]] = [[[] + for _ in range(num_npus)] + for _ in range(layer_num)] + # Iterate to obtain the placement strategy for each layer, taking computational balance into account + max_heat_per_layer_after = np.zeros([layer_num]) + for layer in range(layer_num): + # Get the expert IDs and their corresponding workloads for the current layer; + # workloads need to be normalized, and one redundant expert is added per card + weights = np.zeros((expert_num, ), dtype='object') + for expert_id, workload_weight in enumerate( + layer_workloads[layer]): + weights[expert_id] = (expert_id, workload_weight) + + # Obtain the globally balanced placement strategy for each layer + result, layer_deployment = self.original_compute_balanced_pack_redundancy( + weights, num_npus, num_redundancy_expert) + + global_deployment[layer] = layer_deployment + max_heat_per_layer_after[layer] = max( + result, key=lambda x: x['total_weight'])['total_weight'] + + new_global_deployment = self.constraint_expert_local_exchange( + current_expert_table, global_deployment) + # Obtain the priority of each layer + layer_changed_ratio = [] + for layer_idx in range(layer_num): + layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / + max_heat_per_layer_before[layer_idx]) + + per_layer_priority = np.argsort(layer_changed_ratio) + npu_heat_all_after = sum(max_heat_per_layer_after) + + change = 0 + if npu_heat_all_after < 0.95 * npu_heat_all_origin: + change = 1 + + return change, per_layer_priority, np.array( + new_global_deployment).tolist() diff --git a/vllm_npu/eplb/core/policy/policy_dynamic_ep_v2.py b/vllm_npu/eplb/core/policy/policy_dynamic_ep_v2.py new file mode 100644 index 0000000..a0b8d5d --- /dev/null +++ b/vllm_npu/eplb/core/policy/policy_dynamic_ep_v2.py @@ -0,0 +1,771 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. 
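The tail of rebalance_experts above decides whether to adopt a new placement: it is accepted only when the summed per-layer peak ("hottest rank") load falls below 95% of the original, and layers are prioritised by how much their peak shrank. A small numeric example of that bookkeeping (the load figures are made up):

import numpy as np

max_heat_per_layer_before = np.array([100.0, 80.0])  # hottest rank per layer, old placement
max_heat_per_layer_after = np.array([70.0, 78.0])    # hottest rank per layer, new placement

layer_changed_ratio = max_heat_per_layer_after / max_heat_per_layer_before  # [0.700, 0.975]
per_layer_priority = np.argsort(layer_changed_ratio)                        # layer 0 improved most

change = int(max_heat_per_layer_after.sum()
             < 0.95 * max_heat_per_layer_before.sum())                      # 148 < 171 -> 1

print(change, per_layer_priority)  # 1 [0 1]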
+from abc import abstractmethod +from collections import defaultdict + +import numpy as np + + +class DynamicConfig: + placement_policy = None + + max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host + ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed + num_die_per_host = 8 # Number of dies on each host machine + + +class EplbPolicy: + + def __init__(self, config: DynamicConfig): + self.config = config + + @abstractmethod + def rebalance_experts(self, current_expert_table, expert_workload): + """ + Pass in the weights and return expert replication and placement under relevant constraints. + INPUT: + current_expert_table: [layerId, rankId, expert_num_i] + expert_workload = expert_table[layer0][rankId][expert_num_i] + + RETURNED: (res, expert_table) + res: + 1 -- table_changed + 0 -- not_changed + + expert_table: [layerId, rankId, expert_num_i] + expert_num_i --- [0, MaxExpertPerRank] + expertID = expert_table[layer0][rankId][expert_num_i] + array_values: + [0, 1, 2, 3, 248] + [4, 5, 6, 7, 254] + [8, 9, 10, 11, 71] + ... + [252, 253, 254, 255, 0] + """ + pass + + +class DynamicTable: + # workload_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + workload_table = None + + # placement_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + placement_table = None + + +class DynamicEplbV2(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def safe_divide(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a / b + + @staticmethod + def safe_exact_divide(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a // b + + @staticmethod + def safe_mod(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a % b + + @staticmethod + def add_redundant(current_expert_table, expert_workload, + num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict: dict[int, int] = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][ + expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu: int = int(np.sum(counts - 1)) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer: 
list[float] = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + def calculate_initial_imbalance(self, global_deployment, + new_layer_workloads): + + device_num = global_deployment.shape[1] + layer_imbalance = [] + expert_num = np.zeros_like(new_layer_workloads) + for layer_id, layer in enumerate(global_deployment): + for device in layer: + for expert_id in device: + expert_num[layer_id][expert_id] += 1 + + for layer_id, layer in enumerate(global_deployment): + cur_layer_max_workload = 0 + total_workload = 0 + for box in layer: + box_workload = 0 + for expert_id in box: + update_workload = self.safe_divide( + new_layer_workloads[layer_id][expert_id], + expert_num[layer_id][expert_id]) + box_workload += update_workload + total_workload += update_workload + if cur_layer_max_workload < box_workload: + cur_layer_max_workload = box_workload + + cur_layer_imbalance = self.safe_divide( + cur_layer_max_workload, + (self.safe_divide(total_workload, device_num))) + layer_imbalance.append(cur_layer_imbalance) + + return layer_imbalance + + def compute_redundant_assignments(self, base_experts, + num_redundant_experts, num_experts): + + redundant_assignments: list[list[int]] = [[] + for _ in range(num_experts)] + current_weights = base_experts.copy() + + for i in range(num_redundant_experts): + sorted_indices = np.argsort([w for _, w in current_weights], + kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + target_expert = sorted_weights[0] + expert_id, original_weight = target_expert + + current_redundancy = len(redundant_assignments[expert_id]) + new_avg_weight = self.safe_divide( + original_weight * (current_redundancy + 1), + (current_redundancy + 2)) + + redundant_assignments[expert_id].append(num_experts + i) + current_weights[sorted_indices[0]] = (expert_id, new_avg_weight) + + sorted_indices = np.argsort([w for _, w in current_weights], + kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + return redundant_assignments, sorted_weights + + def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos, + num_experts, num_exist_expert, + device_assignments, device_counts, + expert_from_device, + com_between_devices): + + current_weights = np.zeros((num_experts, ), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads): + current_weights[expert_id] = (expert_id, workload_weight) + + devices_with_slots = [] + for device_id, device_rendun_pos in enumerate(rendun_pos): + if len(device_rendun_pos) != 0: + devices_with_slots.append(device_id) + + while devices_with_slots: + sorted_indices = np.argsort([w for _, w in current_weights], + kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + for index, target_weight in enumerate(sorted_weights): + expert_id, original_weight = target_weight + if original_weight == -1: + print("Error:Redundant expert failure re-occurred") + redundancy_successful = True + break + redundancy_successful = False + for cur_device_id in devices_with_slots: + if expert_id not in device_assignments[cur_device_id]: + pos = rendun_pos[cur_device_id].pop() + if len(rendun_pos[cur_device_id]) == 0: + devices_with_slots = [ + device_id for device_id in devices_with_slots + if device_id != cur_device_id + ] + device_assignments[cur_device_id][pos] = expert_id + device_counts[cur_device_id] += 1 + 
communication_box_index = expert_from_device[expert_id] + com_between_devices[cur_device_id][ + communication_box_index] = expert_id + new_weight = self.safe_divide( + (original_weight * num_exist_expert[expert_id]), + (num_exist_expert[expert_id] + 1)) + sorted_weights[index] = (expert_id, new_weight) + num_exist_expert[expert_id] += 1 + redundancy_successful = True + break + if redundancy_successful: + break + + sorted_indices = np.argsort([id for id, _ in sorted_weights], + kind='stable') + sorted_weights = [sorted_weights[i][1] for i in sorted_indices] + + return sorted_weights, device_assignments, device_counts, com_between_devices + + @staticmethod + def prepare_expert_list(base_experts, redundant_assignments, + num_redundant_experts): + redundant_expert_list = np.empty(num_redundant_experts, dtype=object) + + index = 0 + num_experts = len(redundant_assignments) + for expert_id in range(num_experts): + for _ in redundant_assignments[expert_id]: + redundant_expert_list[index] = (expert_id, + next(w + for eid, w in base_experts + if eid == expert_id)) + index += 1 + + sorted_indices = np.argsort([w for _, w in redundant_expert_list], + kind='stable')[::-1] + return [redundant_expert_list[i] for i in sorted_indices] + + @staticmethod + def non_redundant_expert_information(origin_deployment, updated_weights, + rendun_pos): + + device_num = len(origin_deployment) + num_experts_per_device = origin_deployment.shape[1] + device_assignments = [[-1 for _ in range(num_experts_per_device)] + for _ in range(device_num)] + device_weights = [[0 for _ in range(num_experts_per_device)] + for _ in range(device_num)] + device_loads = [0] * device_num + device_counts = [0] * device_num + + for device_id, device in enumerate(origin_deployment): + for index, expert_id in enumerate(device): + if index in rendun_pos[device_id]: + continue + device_assignments[device_id][index] = expert_id + cur_weight = next( + weight for expert_id_of_weight, weight in updated_weights + if expert_id_of_weight == expert_id) + device_weights[device_id][index] = cur_weight + device_loads[device_id] += cur_weight + device_counts[device_id] += 1 + + return device_assignments, device_weights, device_loads, device_counts + + def recomputing_initial_weight(self, layer_workloads, device_assignments): + num_all_experts = [0] * len(layer_workloads) + for device in device_assignments: + for expert_id in device: + if expert_id != -1: + num_all_experts[expert_id] += 1 + + cur_layer_workload = [] + for expert_id, weight in enumerate(layer_workloads): + if num_all_experts[expert_id] == 0: + cur_layer_workload.append(-1) + else: + cur_layer_workload.append( + self.safe_divide(weight, num_all_experts[expert_id])) + + return cur_layer_workload, num_all_experts + + def distribute_redun_experts(self, layer_workloads, device_assignments, + device_weights, device_loads, device_counts, + redundant_expert_list, expert_from_device, + num_experts, rendun_pos): + + num_devices = len(device_assignments) + com_between_devices: list[dict[int, + int]] = [{} for _ in range(num_devices)] + + for expert_id, weight in redundant_expert_list: + candidate = -1 + for dev_id in range(num_devices): + if len(rendun_pos[dev_id]) == 0: + continue + if expert_id in device_assignments[dev_id]: + continue + if candidate == -1 or device_loads[dev_id] < device_loads[ + candidate]: + candidate = dev_id + if candidate != -1: + pos = rendun_pos[candidate].pop() + device_assignments[candidate][pos] = expert_id + device_weights[candidate][pos] = weight + device_loads[candidate] 
+= weight + device_counts[candidate] += 1 + + communication_box_index = expert_from_device[expert_id] + com_between_devices[candidate][ + communication_box_index] = expert_id + + if any(sublist for sublist in rendun_pos): + cur_layer_workload, num_exist_expert = self.recomputing_initial_weight( + layer_workloads, device_assignments) + + update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments( + cur_layer_workload, rendun_pos, num_experts, num_exist_expert, + device_assignments, device_loads, expert_from_device, + com_between_devices) + + device_loads = [0] * len(device_counts) + for device_id, device in enumerate(device_assignments): + for index, expert_id in enumerate(device): + device_weights[device_id][index] = update_workload[ + expert_id] + device_loads[device_id] += update_workload[expert_id] + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + def redundancy_again(self, layer_workloads, origin_weights, + origin_deployment, expert_from_device, num_node, + is_node_redundant, rendun_pos): + + num_experts = len(origin_weights) + if is_node_redundant: + num_experts = num_experts * num_node + + num_redundant_experts = 0 + for rank_empty_pos in rendun_pos: + num_redundant_experts += len(rank_empty_pos) + + redundant_assignments, updated_weights = self.compute_redundant_assignments( + origin_weights, num_redundant_experts, num_experts) + + redundant_expert_list = self.prepare_expert_list( + updated_weights, redundant_assignments, num_redundant_experts) + + device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information( + origin_deployment, updated_weights, rendun_pos) + + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts( + layer_workloads, device_assignments, device_weights, device_loads, + device_counts, redundant_expert_list, expert_from_device, + num_experts, rendun_pos) + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + @staticmethod + def generate_allocation_report(device_assignments, device_weights, + device_loads, device_counts): + + report = [] + max_load = 0.0 + + for dev_id in range(len(device_assignments)): + current_load = device_loads[dev_id] + max_load = max(max_load, current_load) + + report.append({ + "device_id": dev_id + 1, + "assigned_experts": device_assignments[dev_id], + "expert_weights": device_weights[dev_id], + "total_load": current_load, + "expert_count": device_counts[dev_id] + }) + + return report, max_load + + @staticmethod + def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id, + next_device_id, cur_layer_result, com_between_devices): + + cur_device_deployment = cur_layer_result[cur_device_id][ + 'assigned_experts'] + next_device_deployment = cur_layer_result[next_device_id][ + 'assigned_experts'] + + cur_device_weight = cur_layer_result[cur_device_id]['expert_weights'] + next_device_weight = cur_layer_result[next_device_id]['expert_weights'] + + cur_expert_id = cur_device_deployment[cur_exchange_index] + next_expert_id = next_device_deployment[next_exchange_index] + cur_device_deployment[cur_exchange_index] = next_expert_id + next_device_deployment[next_exchange_index] = cur_expert_id + + cur_expert_weight = cur_device_weight[cur_exchange_index] + next_expert_weight = next_device_weight[next_exchange_index] + cur_device_weight[cur_exchange_index] = next_expert_weight + 
next_device_weight[next_exchange_index] = cur_expert_weight + + cur_layer_result[cur_device_id][ + 'total_load'] += next_expert_weight - cur_expert_weight + cur_layer_result[next_device_id][ + 'total_load'] += cur_expert_weight - next_expert_weight + + com_between_devices[cur_device_id][next_device_id] = next_expert_id + com_between_devices[next_device_id][cur_device_id] = cur_expert_id + + def redundant_expert_deployment(self, layer_workloads, original_deployment, + expert_from_device, node_num, + is_node_redundant, rendun_pos): + device_num, per_device_expert_num = original_deployment.shape + route_expert_num = layer_workloads.shape[0] + per_node_device_num = self.safe_exact_divide(device_num, node_num) + per_node_route_expert_num = per_node_device_num * ( + per_device_expert_num - 1) + + weights = np.zeros((route_expert_num, ), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads): + weights[expert_id] = (expert_id, workload_weight) + + if is_node_redundant: + + device_assignments = [] + device_weights = [] + device_loads = [] + device_counts = [] + com_between_devices = [] + + for node_id in range(node_num): + cur_node_weights = weights[node_id * + per_node_route_expert_num:(node_id + + 1) * + per_node_route_expert_num] + cur_original_deployment = original_deployment[ + node_id * per_node_device_num:(node_id + 1) * + per_node_device_num] + + cur_node_rendun_pos = rendun_pos[node_id * + per_node_device_num:(node_id + + 1) * + per_node_device_num] + + cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again( + layer_workloads, cur_node_weights, cur_original_deployment, + expert_from_device, node_num, is_node_redundant, + cur_node_rendun_pos) + device_assignments += cur_device_assignments + device_weights += cur_device_weights + device_loads += cur_device_loads + device_counts += cur_device_counts + com_between_devices += cur_com_between_devices + + else: + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again( + layer_workloads, weights, original_deployment, + expert_from_device, node_num, is_node_redundant, rendun_pos) + report, max_load = self.generate_allocation_report( + device_assignments, device_weights, device_loads, device_counts) + + return report, max_load, com_between_devices + + @staticmethod + def two_device_exchange_experts(cur_device_result, exchange_device_result, + cur_exchanged_expert_id, + next_exchanged_expert_id, ave_workload, + increment, num_redundancy_expert): + + cur_device_weight = cur_device_result['expert_weights'] + next_device_weight = exchange_device_result['expert_weights'] + + cur_device_expert_id = cur_device_result['assigned_experts'] + next_device_expert_id = exchange_device_result['assigned_experts'] + + cur_device_total_weight = cur_device_result['total_load'] + next_device_total_weight = exchange_device_result['total_load'] + max_weight = max(cur_device_total_weight, next_device_total_weight) + + cur_exchange_index = -1 + next_exchange_index = -1 + + for index, weight in enumerate(cur_device_weight): + for next_index, next_weight in enumerate(next_device_weight): + change_flag = True + if (cur_device_expert_id[index] in next_device_expert_id + or next_device_expert_id[next_index] + in cur_device_expert_id): + change_flag = False + if (cur_device_expert_id[index] not in cur_exchanged_expert_id + ) and (next_device_expert_id[next_index] + not in next_exchanged_expert_id) and change_flag: + + 
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight + next_total_weight_after_exchange = next_device_total_weight - next_weight + weight + exchange_max_weight = max( + cur_total_weight_after_exchange, + next_total_weight_after_exchange) + if exchange_max_weight < max_weight and ( + max_weight - + exchange_max_weight) >= (ave_workload * increment): + max_weight = exchange_max_weight + cur_exchange_index = index + next_exchange_index = next_index + + return cur_exchange_index, next_exchange_index + + def expert_exchange_between_devices(self, + ave_workload, + increment, + cur_layer_result, + com_between_devices, + num_redundancy_expert, + node_idx=0, + per_node_device_num=0, + is_node_redundant=False): + + if is_node_redundant: + cur_devices_result = cur_layer_result[node_idx * + per_node_device_num: + (node_idx + 1) * + per_node_device_num] + else: + cur_devices_result = cur_layer_result + + devices_total_weight = [] + for device in cur_devices_result: + devices_total_weight.append( + (device['total_load'], device['device_id'] - 1)) + + exchange_frequency = 100 + while exchange_frequency > 0: + exchange_frequency -= 1 + devices_total_weight.sort(key=lambda x: x[0]) + max_weight_device_id = devices_total_weight[-1][1] + exchange = False + for index in range(0, len(devices_total_weight) - 1): + min_weight_device_id = devices_total_weight[index][1] + if min_weight_device_id not in com_between_devices[ + max_weight_device_id]: + cur_exchanged_expert_id = list( + com_between_devices[max_weight_device_id].values()) + next_exchanged_expert_id = list( + com_between_devices[min_weight_device_id].values()) + + cur_exchange_index, next_exchange_index = self.two_device_exchange_experts( + cur_layer_result[max_weight_device_id], + cur_layer_result[min_weight_device_id], + cur_exchanged_expert_id, next_exchanged_expert_id, + ave_workload, increment, num_redundancy_expert) + + if cur_exchange_index != -1: + self.exchange_expert(cur_exchange_index, + next_exchange_index, + max_weight_device_id, + min_weight_device_id, + cur_layer_result, + com_between_devices) + + devices_total_weight[-1] = ( + cur_layer_result[max_weight_device_id] + ['total_load'], max_weight_device_id) + devices_total_weight[index] = ( + cur_layer_result[min_weight_device_id] + ['total_load'], min_weight_device_id) + exchange = True + break + + if not exchange: + break + + def exchange_experts(self, layer_result, layer_com_between_devices, + num_nodes, device_num, is_node_redundant, + ave_workload, increment, num_redundancy_expert, + org_deployment): + + global_deployment = [] + + if is_node_redundant: + per_node_device_num = self.safe_exact_divide(device_num, num_nodes) + for node_idx in range(num_nodes): + self.expert_exchange_between_devices( + ave_workload, increment, layer_result, + layer_com_between_devices, num_redundancy_expert, node_idx, + per_node_device_num, is_node_redundant) + else: + self.expert_exchange_between_devices(ave_workload, increment, + layer_result, + layer_com_between_devices, + num_redundancy_expert) + + max_workload = 0 + for box in layer_result: + global_deployment.append(box['assigned_experts']) + if max_workload < box['total_load']: + max_workload = box['total_load'] + + global_deployment = np.array(global_deployment) + + return global_deployment, max_workload + + def count_elements(self, lst): + count = 0 + for item in lst: + if isinstance(item, list): + count += self.count_elements(item) + else: + count += 1 + return count + + @staticmethod + def 
constraint_expert_local_exchange(current_expert_table, + global_deployment): + for layer_id in range(len(global_deployment)): + for card_id in range(len(global_deployment[layer_id])): + current_list = [ + int(x) for x in current_expert_table[layer_id][card_id] + ] + new_list = [ + int(x) for x in global_deployment[layer_id][card_id] + ] + num = len(new_list) + + new_index = [-1] * num + new_result = [-1] * num + remaining_elements = [] + + for i in range(num): + flag = True + for j in range(num): + if new_list[i] == current_list[j] and new_index[ + j] == -1: + new_index[j] = 0 + new_result[j] = current_list[j] + flag = False + break + if flag: + remaining_elements.append(new_list[i]) + + index = 0 + for k in range(num): + if new_result[k] == -1: + new_result[k] = remaining_elements[index] + index += 1 + + global_deployment[layer_id][card_id] = new_result + + return global_deployment + + def rebalance_experts(self, + current_expert_table, + expert_workload, + is_node_redundant=False, + increment=0.01): + info = DynamicTable() + info.workload_table = expert_workload.numpy() + info.placement_table = current_expert_table.numpy() + assert info.workload_table is not None + layer_num, num_npus, experts_per_npu = info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], + return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + layer_workloads = self.add_redundant(info.placement_table, + info.workload_table, + num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer( + info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + num_node = self.safe_exact_divide(num_npus, 8) + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + expert_from_device = np.zeros((layer_num, num_original_expert)) + + if num_original_expert != expert_num: + raise ValueError( + f"The number of original experts ({num_original_expert}) must match expert_num ({expert_num})" + ) + + if num_npus <= 0: + raise ValueError("The number of NPUs must be greater than 0") + + if num_npus < num_redundancy_expert: + raise ValueError( + f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})" + ) + + global_deployment: list[list[list[int]]] = [[[] + for _ in range(num_npus)] + for _ in range(layer_num)] + layer_initial_imbalance = self.calculate_initial_imbalance( + info.placement_table, layer_workloads) + max_heat_per_layer_after = np.zeros([layer_num]) + sum_num = 0 + for layer in range(layer_num): + # print(f"Load imbalance ratio of layer {layer} under the new workload", layer_initial_imbalance[layer]) + if layer_initial_imbalance[layer] < 1.01: + global_deployment[layer] = info.placement_table[layer] + continue + + ave_workload = self.safe_divide(np.sum(layer_workloads[layer]), + num_npus) + + rendun_pos: list[list[int]] = [[] for _ in range(num_npus)] + existing_experts = set() + for device_id, device in enumerate(info.placement_table[layer]): + for index, expert_id in enumerate(device): + if expert_id not in existing_experts: + existing_experts.add(expert_id) + expert_from_device[layer][expert_id] = device_id + else: + rendun_pos[device_id].append(index) + + result, max_workload, com_between_devices = self.redundant_expert_deployment( + layer_workloads[layer], info.placement_table[layer], + expert_from_device[layer], num_node, is_node_redundant, + rendun_pos) + # print(layer, 
f"Imbalance Ratio after Redundancy Adjustment:", self.safe_divide(max_workload, ave_workload)) + + global_deployment[layer], new_max_workload = self.exchange_experts( + result, com_between_devices, num_node, num_npus, + is_node_redundant, ave_workload, increment, + num_redundancy_expert, info.placement_table[layer]) + # print(layer, f"Imbalance Ratio after Swap Adjustment:", self.safe_divide(new_max_workload, ave_workload)) + + for device_id in range(num_npus): + com_between_devices[device_id] = { + key: value + for key, value in com_between_devices[device_id].items() + } + sum_num += self.count_elements(com_between_devices[device_id]) + + max_heat_per_layer_after[layer] = max( + result, key=lambda x: x['total_load'])['total_load'] + + layer_changed_ratio = [] + for layer_idx in range(layer_num): + layer_changed_ratio.append( + self.safe_divide(max_heat_per_layer_after[layer_idx], + max_heat_per_layer_before[layer_idx])) + + per_layer_priority = np.argsort(layer_changed_ratio) + npu_heat_all_after = sum(max_heat_per_layer_after) + + change = 0 + if npu_heat_all_after < 0.95 * npu_heat_all_origin: + change = 1 + + new_global_deployment = self.constraint_expert_local_exchange( + current_expert_table, global_deployment) + + return change, per_layer_priority, np.array( + new_global_deployment).tolist() diff --git a/vllm_npu/eplb/core/policy/policy_factory.py b/vllm_npu/eplb/core/policy/policy_factory.py new file mode 100644 index 0000000..bbf7315 --- /dev/null +++ b/vllm_npu/eplb/core/policy/policy_factory.py @@ -0,0 +1,33 @@ +# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory. +from .policy_abstract import DynamicConfig, EplbPolicy +from .policy_dynamic_ep import DynamicEplb +from .policy_dynamic_ep_v2 import DynamicEplbV2 +from .policy_flashlb import FlashLB +from .policy_random import RandomLoadBalance + + +class PolicyFactory: + + @staticmethod + def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy: + policy = { + # Constraint applying Dynamic EPLB policy V2: + # If there exists redundant expert: + # only one redundant expert can be placed in one NPU and its physical expert index must be 0 + + # Applying greedy d2d expert weight update composing + 0: + RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3 + 1: + DynamicEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load + 2: + DynamicEplbV2, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle + 3: + FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment + } + policy_class = policy.get(policy_type, RandomLoadBalance) + policy_instance = policy_class(config) + if policy_type == 3: + policy_instance.warm_up() + return policy_instance \ No newline at end of file diff --git a/vllm_npu/eplb/core/policy/policy_flashlb.py b/vllm_npu/eplb/core/policy/policy_flashlb.py new file mode 100644 index 0000000..2bf6551 --- /dev/null +++ b/vllm_npu/eplb/core/policy/policy_flashlb.py @@ -0,0 +1,651 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. 
+ +import logging +from collections import deque +from typing import Dict + +import numpy as np +import torch +from numba import njit # type: ignore + +from .policy_abstract import DynamicConfig, EplbPolicy + +numba_logger = logging.getLogger("numba") +numba_logger.setLevel(logging.WARNING) + + +@njit +def compute_piece_counts(X, P, stage_weights): + n_stage, N = X.shape + S = P - N + pieces = np.ones(N, dtype=np.int32) + unit = X / pieces # unit[i, j] = X[i, j] / pieces[j] + + for _ in range(S): + deltas = np.zeros(N, dtype=np.float32) + for i in range(n_stage): + # Find top1 and top2 + idx1 = -1 + idx2 = -1 + val1 = -1.0 + val2 = -1.0 + for j in range(N): + v = unit[i, j] + if v > val1: + val2 = val1 + idx2 = idx1 + val1 = v + idx1 = j + elif v > val2: + val2 = v + idx2 = j + + origin = unit[i, idx1] + secv = unit[i, idx2] + alt = X[i, idx1] / (pieces[idx1] + 1) + delta = origin - (alt if alt > secv else secv) + deltas[idx1] += delta * stage_weights[i] if np.any( + delta) != 0 else stage_weights[i] + + max_idx = np.argmax(deltas) + pieces[max_idx] += 1 + for i in range(n_stage): + unit[i, max_idx] = X[i, max_idx] / pieces[max_idx] + + # Compute max load + max_load = 0.0 + for j in range(N): + total = 0.0 + for i in range(n_stage): + total += unit[i, j] + if total > max_load: + max_load = total + + return pieces + + +@njit +def jsq_placement(X, pieces, M, stage_weights): + n_stage, N = X.shape + total_piece = pieces.sum() + num_per_group = total_piece // M + + # 1. Compute unit_hotness + unit_hotness = np.empty((n_stage, N), dtype=np.float32) + for i in range(N): + if pieces[i] > 0: + for s in range(n_stage): + unit_hotness[s, i] = X[s, i] / pieces[i] + else: + for s in range(n_stage): + unit_hotness[s, i] = 0.0 + + # 2. Sort by total hotness + scores = np.zeros(N, dtype=np.float32) + for i in range(N): + for s in range(n_stage): + scores[i] += unit_hotness[s, i] + idx = np.argsort(-scores) + + # 3. Initialization + loads = np.zeros((n_stage, M), dtype=np.float32) + dev_phy_exp_n = np.zeros(M, dtype=np.int32) + deployment = -np.ones((M, num_per_group), dtype=np.int32) + dep_ptr = np.zeros(M, dtype=np.int32) + + # 4. 
Main loop + for t in range(N): + i = idx[t] + used_device = list() + for _ in range(pieces[i]): + # 4.1 Construct w vector + w = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + w[s] = unit_hotness[s, i] + + # 4.2 Compute stage-level maximum load + stage_max = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + max_val = loads[s, 0] + for k in range(1, M): + if loads[s, k] > max_val: + max_val = loads[s, k] + stage_max[s] = max_val + + # 4.3 Compute denominator + denom = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + sum_tmp = 0.0 + for j in range(M): + sum_tmp += loads[s, j] + w[s] + denom[s] = sum_tmp / M + 1e-2 + + # 4.4 Find best device j + best_j = -1 + best_val = 1e30 + for j in range(M): + if dev_phy_exp_n[j] >= num_per_group: + continue + if j in used_device: + continue + score = 0.0 + for s in range(n_stage): + tmp_sj = loads[s, j] + w[s] + numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s] + score += stage_weights[s] * (numer_sj / denom[s]) + if score < best_val: + best_val = score + best_j = j + if best_j == -1: + continue + + used_device.append(best_j) + + # 4.5 Update status + for s in range(n_stage): + loads[s, best_j] += w[s] + ptr = dep_ptr[best_j] + deployment[best_j, ptr] = i + dep_ptr[best_j] += 1 + dev_phy_exp_n[best_j] += 1 + + # Handle remaining -1 values: fill with random elements from range(N) not in current column + for rank in range(M): + for col in range(num_per_group): + if deployment[rank, col] == -1: + # Get elements already in current column + current_rank_elements = set(deployment[rank, :]) + # Filter elements from range(N) not in current column + available = [ + x for x in range(N) if x not in current_rank_elements + ] + # Randomly select an available element to fill + if len(available) > 0: + rand_idx = np.random.randint(0, len(available)) + deployment[rank, col] = available[rand_idx] + elif N > 0: + # All unique experts are already in this rank's column, so we can pick any expert randomly. 
+ deployment[rank, col] = np.random.randint(0, N) + + return deployment + + +@njit +def slice_values(X, pieces): + total_len = 0 + for i in range(X.shape[0]): + total_len += pieces[i] + result = np.empty(total_len, dtype=np.float32) + idx = 0 + for i in range(X.shape[0]): + val = X[i] / pieces[i] + for _ in range(pieces[i]): + result[idx] = val + idx += 1 + return result + + +@njit +def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, + simulated_deployment, stage_weights): + n_stage, N = X.shape + num_group = P // M + + X_all = np.zeros(N, dtype=np.float32) + for i in range(n_stage): + for j in range(N): + X_all[j] += X[i, j] + + sort_idx = np.argsort(np.negative(X_all)) + X_sorted = X[:, sort_idx] + + unit_load = np.empty(N, dtype=np.float32) + for j in range(N): + unit_load[j] = X_all[j] / simulated_pieces[j] + + flat_deployment = simulated_deployment.reshape(-1) + simulated_load = np.zeros(M, dtype=np.float32) + for i in range(flat_deployment.shape[0]): + simulated_load[i // (flat_deployment.shape[0] // + M)] += unit_load[flat_deployment[i]] + + slice_vals = slice_values(X_all, simulated_pieces) + sorted_slices = np.sort(slice_vals)[::-1] + simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M + + cumulative_slices_used = np.zeros(N, dtype=np.int32) + acc = 0 + for i in range(N): + acc += simulated_pieces[sort_idx[i]] + cumulative_slices_used[i] = acc + + group_boundary_indices = np.zeros(num_group, dtype=np.int32) + for i in range(1, num_group + 1): + for j in range(N): + if cumulative_slices_used[j] >= i * M: + group_boundary_indices[i - 1] = j + break + + slices_used_per_group = np.zeros(num_group, dtype=np.int32) + slices_used_per_group[0] = group_boundary_indices[0] + for i in range(1, num_group): + slices_used_per_group[ + i] = group_boundary_indices[i] - group_boundary_indices[i - 1] + slices_used_per_group = M - slices_used_per_group + + loads = np.zeros(M, dtype=np.float32) + pieces = np.zeros(N, dtype=np.int32) + num_remain_slice = P - N + current_idx = 0 + + for g in range(num_group): + window = X_sorted[:, current_idx:current_idx + 2 * M] + low = max(0, current_idx + M - N) + high = min(num_remain_slice, M - 1) + + while (high - low) > 1: + mid = int((high + low) // 2) + keep = M - mid + current_group = window[:, :keep] + current_pieces = compute_piece_counts(current_group, M, + stage_weights) + current_pieces = np.maximum(current_pieces, 1) + current_slice = slice_values(current_group.sum(0), current_pieces) + current_slice_sorted = np.sort(current_slice) + current_loads = loads + current_slice_sorted + current_max: np.float32 = np.max(current_loads) + current_min: np.float32 = np.min(current_loads) + current_slope = (current_max - current_min) / M + next_slope: np.float32 = np.max(simulated_slopes[current_idx + + keep:]) + + if abs(current_slope) > abs(next_slope): + low = mid + else: + high = mid + + S = high + keep = M - S + current_group = window[:, :keep] + current_pieces = compute_piece_counts(current_group, M, stage_weights) + + for i in range(keep): + pieces[sort_idx[current_idx + i]] = current_pieces[i] + + current_slice = slice_values(current_group.sum(0), current_pieces) + current_slice_sorted = np.sort(current_slice) + loads += current_slice_sorted + loads = np.sort(loads)[::-1] + + current_idx += keep + num_remain_slice -= S + + return pieces + + +@njit +def compute_objective(deployment, X, pieces): + M, P = deployment.shape + loads = np.zeros(M) + + for i in range(M): + for j in range(P): + expert = deployment[i, j] + if 
pieces[expert] == 0: + continue + loads[i] += X[expert] / pieces[expert] + + mean_load = np.mean(loads) + max_load: np.float32 = np.max(loads) + obj = max_load / mean_load + return obj, loads + + +@njit +def auto_fix_new_placement(old_placement, new_placement): + """ + Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both + old_placement and new_placement remain in their original positions from old_placement. + New elements (unique to new_placement) will fill the remaining empty positions. + + Args: + old_placement: Old deployment matrix with shape (num_ranks, num_experts) + new_placement: New deployment matrix to be fixed, must have the same shape as old_placement + + Returns: + fixed_new: adjusted version of the new_placement matrix + """ + num_ranks, num_experts = old_placement.shape + fixed_new = np.empty_like(new_placement) + + max_expert_old = old_placement.max() if num_experts > 0 else 0 + max_expert_new = new_placement.max() if num_experts > 0 else 0 + max_expert = max(max_expert_old, max_expert_new) + + for rank_id in range(num_ranks): + old_row = old_placement[rank_id] + new_row = new_placement[rank_id] + + index_array = np.full((max_expert + 1, num_experts), + -1, + dtype=np.int32) + count_array = np.zeros(max_expert + 1, dtype=np.int32) + + for idx in range(num_experts): + val = old_row[idx] + if val >= 0 and val <= max_expert: + pos = count_array[val] + index_array[val, pos] = idx + count_array[val] += 1 + + old_counter = np.zeros(max_expert + 1, dtype=np.int32) + for idx in range(num_experts): + val = old_row[idx] + if val >= 0 and val <= max_expert: + old_counter[val] += 1 + + retain_elements = np.empty(num_experts, dtype=new_placement.dtype) + new_elements = np.empty(num_experts, dtype=new_placement.dtype) + retain_ptr = 0 + new_ptr = 0 + + for val in new_row: + if val >= 0 and val <= max_expert and old_counter[val] > 0: + retain_elements[retain_ptr] = val + retain_ptr += 1 + old_counter[val] -= 1 + else: + new_elements[new_ptr] = val + new_ptr += 1 + + current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype) + + for i in range(retain_ptr): + val = retain_elements[i] + if val >= 0 and val <= max_expert: + pos = count_array[val] - 1 + if pos >= 0: + idx = index_array[val, pos] + current_fixed[idx] = val + count_array[val] -= 1 + + empty_indices = np.empty(num_experts, dtype=np.int32) + empty_ptr = 0 + for idx in range(num_experts): + if current_fixed[idx] == -1: + empty_indices[empty_ptr] = idx + empty_ptr += 1 + + for i in range(new_ptr): + if i < empty_ptr: + current_fixed[empty_indices[i]] = new_elements[i] + + fixed_new[rank_id] = current_fixed + + return fixed_new + + +class FlashLB(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + self.par_history: Dict[int, float] = {} + self.hotness_window: Dict[int, deque[float]] = {} + self.max_stage_window = (config.max_stage_window if hasattr( + config, "max_stage_window") else 1) + self.buffer_expert_layer_num = ( + config.buffer_expert_layer_num if hasattr( + config, "buffer_expert_layer_num") else 58) + self.threshold_ratio = (config.threshold_ratio if hasattr( + config, "threshold_ratio") else 0) + + def compute_expert_hotness(self, num_of_expert: int, + deployment: np.ndarray, rank_load: np.ndarray): + hotness = np.zeros(num_of_expert, dtype=rank_load.dtype) + deployment_flat = deployment.ravel() + rank_load_flat = rank_load.ravel() + np.add.at(hotness, deployment_flat, rank_load_flat) + return hotness + + def compute_rank_load(self, 
deployment: np.ndarray, hotness: np.ndarray): + n_stage, N = hotness.shape + if np.any(deployment < 0): + print(f"Invalid deployment with negative values: {deployment}") + raise ValueError("Deployment table contains negative values.") + counts = np.bincount(deployment.reshape(-1), minlength=N) + unit_hotness = np.divide(hotness, + counts, + out=np.zeros_like(hotness, dtype=float), + where=counts != 0) + stage_par = np.zeros(n_stage) + for i in range(n_stage): + stage_load = unit_hotness[i][deployment].sum(-1) + stage_par[i] = stage_load.max() / stage_load.mean() + return stage_par.mean() + + def group_based_adaptive_bloating(self, + X, + P, + M, + stage_weights=None, + recorsive=False): + n_stage, N = X.shape + if stage_weights is None: + stage_weights = np.ones(n_stage, dtype=np.float32) + + if recorsive: + ( + simulated_deployment, + simulated_pieces, + ) = self.group_based_adaptive_bloating(X, + P, + M, + stage_weights, + recorsive=False) + else: + simulated_pieces = compute_piece_counts(X, P, stage_weights) + simulated_deployment = jsq_placement(X, simulated_pieces, M, + stage_weights) + + pieces = group_based_adaptive_bloating_kernel( + X.astype(np.float32), + P, + M, + simulated_pieces.astype(np.int32), + simulated_deployment.astype(np.int32), + stage_weights.astype(np.float32), + ) + + deployment = jsq_placement(X, pieces, M, stage_weights) + + X_all = X.sum(0) + unit_load = np.divide(X_all, + pieces, + out=np.zeros_like(X_all, dtype=float), + where=pieces != 0) + load = unit_load[deployment].sum(-1) + + sim_unit_load = X_all / simulated_pieces + sim_load = sim_unit_load[simulated_deployment].sum(-1) + + if load.max() > sim_load.max(): + return simulated_deployment, simulated_pieces + return deployment, pieces + + def need_update(self, current_par, layer_id=0): + threshold = self.par_history.get(layer_id, 0.0) + return current_par >= self.threshold_ratio * threshold + + def compute_stage_weight(self, hotness): + n_stage = hotness.shape[0] + stage_weights = np.zeros(n_stage) + for i in range(n_stage): + stage_weights[i] = hotness[i].sum() + + stage_weights = stage_weights / stage_weights.max() + return stage_weights + + def rebalance_layer(self, deployment, hotness, layer_id=0): + num_rank, expert_per_rank = deployment.shape + num_expert = np.unique(deployment.reshape(-1)).shape[0] + num_of_redundant_expert = num_rank * expert_per_rank - num_expert + + current_par = self.compute_rank_load(deployment, hotness) + + if not self.need_update(current_par, layer_id): + return deployment, current_par, current_par + + stage_weights = self.compute_stage_weight(hotness) + new_deployment, _ = self.group_based_adaptive_bloating( + hotness, + num_expert + num_of_redundant_expert, + num_rank, + stage_weights, + recorsive=False, + ) + if np.any(new_deployment < 0): + print(f"{new_deployment=}") + new_par = self.compute_rank_load(new_deployment, hotness) + + return new_deployment, new_par, current_par + + def register_hotness(self, deployment, rank_load, num_layer, num_expert): + for layer in range(num_layer): + if layer not in self.hotness_window: + self.hotness_window[layer] = deque( + maxlen=self.max_stage_window) + hotness = self.compute_expert_hotness(num_expert, + deployment[layer], + rank_load[layer]) + self.hotness_window[layer].append(hotness) + + def compress_by_avg_pooling_fast_nd(self, arr, m): + n, d = arr.shape + idx = (np.arange(n) * m // n) + result = np.zeros((m, d)) + counts = np.zeros((m, 1)) + np.add.at(result, idx, arr) + np.add.at(counts, idx, 1) + return result / counts + + def 
rebalance_experts(self, current_expert_table, expert_workload): + current_deployment = np.array(current_expert_table) + expert_workload = np.array(expert_workload) + expert_workload += 1 + num_layer = expert_workload.shape[0] + num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0] + self.register_hotness(current_deployment, expert_workload, num_layer, + num_expert) + + new_deployment = current_deployment.copy() + + layers_need_update = np.arange(num_layer) + + new_par = np.zeros(layers_need_update.shape[0]) + current_par = np.zeros(layers_need_update.shape[0]) + for i, layer in enumerate(layers_need_update): + hotness = np.array(self.hotness_window[layer]) + if hotness.shape[0] > self.max_stage_window: + hotness = self.compress_by_avg_pooling_fast_nd( + hotness, self.max_stage_window) + + ( + new_deployment[layer], + new_par[i], + current_par[i], + ) = self.rebalance_layer(current_deployment[layer], + hotness, + layer_id=layer) + + priority = new_par / current_par + priority_idx = np.argsort(priority) + priority_idx = priority_idx[priority[priority_idx] < + 1][:self.buffer_expert_layer_num] + + if np.all(expert_workload == 1): + for _, layer in enumerate(layers_need_update): + self.hotness_window[layer].pop() + return False, np.array([], dtype=int), current_deployment + change = len(priority_idx) > 0 + if change: + for idx in priority_idx: + self.par_history[layers_need_update[idx]] = new_par[idx] + + layers_need_update = priority_idx + deployment = current_deployment + for layer in layers_need_update: + deployment[layer] = auto_fix_new_placement( + current_deployment[layer], new_deployment[layer]) + + return change, layers_need_update, deployment + + +def generate_layered_experts(num_layers=58, + layer_shape=(32, 9), + expert_min=0, + expert_max=255): + """ + Generate expert deployment matrix meeting the following conditions: + - Total of num_layers layers + - Each layer has shape layer_shape (32,9) + - Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer + + Args: + num_layers: Number of layers, default 58 + layer_shape: Shape of a single layer, default (32,9) + expert_min: Minimum expert ID, default 0 + expert_max: Maximum expert ID, default 255 + Returns: + torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1]) + """ + # 1. Basic parameter calculation + expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255) + layer_total = layer_shape[0] * layer_shape[ + 1] # Total elements in a single layer: 32*9=288 + extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32 + + # 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts) + assert layer_total >= expert_num, ( + f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, " + "cannot cover all experts") + + # 3. 
Generate layers one by one + layers = [] + for _ in range(num_layers): + # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included) + full_experts = torch.arange(expert_min, + expert_max + 1, + dtype=torch.int64) # shape (256,) + + # 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255) + extra_experts = torch.randint(expert_min, + expert_max + 1, + size=(extra_slots, ), + dtype=torch.int64) # shape (32,) + + # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer) + layer_flat = torch.cat([full_experts, extra_experts], + dim=0) # shape (288,) + # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues) + shuffle_idx = torch.randperm(layer_flat.shape[0]) + layer_shuffled = layer_flat[shuffle_idx] + + # 3.4 Reshape to layer_shape (32,9) + layer = layer_shuffled.reshape(layer_shape) + layers.append(layer) + + # 4. Stack all layers to get the final tensor + return torch.stack(layers, dim=0) # shape (58,32,9) + + +def warm_up(): + exam_config = DynamicConfig() + exam_config.ep_worldsize = 32 + exam_config.num_die_per_host = 16 + algo = FlashLB(exam_config) + # Generate target tensor + expert_tensor = generate_layered_experts(num_layers=58, + layer_shape=(32, 9)) + + algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9))) diff --git a/vllm_npu/eplb/core/policy/policy_random.py b/vllm_npu/eplb/core/policy/policy_random.py new file mode 100644 index 0000000..558d653 --- /dev/null +++ b/vllm_npu/eplb/core/policy/policy_random.py @@ -0,0 +1,30 @@ +# Copyright # Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. +import copy +import random + +from .policy_abstract import DynamicConfig, EplbPolicy + +random.seed(42) + + +class RandomLoadBalance(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + def rebalance_experts(self, current_expert_table, expert_workload): + new_table = copy.deepcopy(current_expert_table) + num_layers = len(current_expert_table) + + for i in range(num_layers): + # randomly choose two card + # indices = random.sample(range(num_card), 2) + indices = [3, 1] + + # swap redundant experts + expert_id_to_exchange = new_table[i][indices[0]][-1].clone() + new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1] + new_table[i][indices[1]][-1] = expert_id_to_exchange + + return 1, [-i for i in range(num_layers)], new_table diff --git a/vllm_npu/eplb/eplb_updator.py b/vllm_npu/eplb/eplb_updator.py new file mode 100644 index 0000000..fdbf819 --- /dev/null +++ b/vllm_npu/eplb/eplb_updator.py @@ -0,0 +1,209 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. 
Remove this updator. +import numpy +import torch +import torch.distributed as dist +import vllm.envs as envs +from vllm.logger import logger + +from vllm_npu.eplb.core.eplb_utils import EPLBParamUtils +from vllm_npu.eplb.core.eplb_worker import EplbProcess + + +class EplbUpdator: + + def __init__(self, ascend_config, loader, eplb_process: EplbProcess, + process): + self.ascend_config = ascend_config + self.init_eplb(self.ascend_config.expert_map_path, process) + self.eplb_loader = loader + self.eplb_process = eplb_process + self.shared_dict = self.eplb_process.shared_dict + + def set_adaptor(self, adaptor): + self.adaptor = adaptor + self.num_moe_layers = self.adaptor.num_moe_layers + self.global_expert_num = self.adaptor.global_expert_num + + def init_eplb(self, expert_map_path, process): + self.rank_id = dist.get_rank() + self.num_expert_load_gather = 10 + self.periodic_load_gather = True + self.num_iterations_eplb_update: torch.int64 = self.ascend_config.num_iterations_eplb_update + EPLBParamUtils.check_iterations(self.num_iterations_eplb_update) + self.expert_map_path = expert_map_path + self.expert_map_record_path = self.ascend_config.expert_map_record_path + + try: + if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + except Exception: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + + self.expert_map_initialized = False + self.gate_eplb = self.ascend_config.gate_eplb + + self.reqs = [] + self.update_info_all = [] + + self.cur_iterations: torch.int64 = 0 + + self.num_wait_worker_iterations: torch.int64 = self.ascend_config.num_wait_worker_iterations + EPLBParamUtils.check_iterations(self.num_wait_worker_iterations) + + self.process = process + + logger.info( + f"[ModelRunner] Launched EPLB process (pid={self.process.pid})") + + def update_iteration(self): + self.cur_iterations += 1 + if self.cur_iterations == (self.num_iterations_eplb_update + \ + self.num_wait_worker_iterations + self.num_moe_layers): + if self.expert_map_record_path is not None: + self.adaptor._export_tensor_to_file( + self.shared_dict["expert_maps"], + self.expert_map_record_path) + + self.adaptor.model.clear_all_moe_loads() + if not self.gate_eplb: + self.cur_iterations = 0 + + def get_update_info_flag(self): + return self.cur_iterations == (self.num_iterations_eplb_update + + self.num_wait_worker_iterations - 1) + + def wakeup_eplb_worker_flag(self): + return self.cur_iterations == (self.num_iterations_eplb_update - 1) + + def update_expert_weight_flag(self): + weight_update_counter = self.cur_iterations - ( + self.num_iterations_eplb_update + self.num_wait_worker_iterations) + return (weight_update_counter >= 0 + and weight_update_counter < self.num_moe_layers) + + def get_init_expert_map(self): + try: + if not self.expert_map_initialized: + self.shared_dict[ + "expert_maps"] = self.adaptor.get_init_expert_map_from_file( + self.num_moe_layers, self.expert_map_path) + self.expert_map_initialized = True + except Exception as e: + logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}", + exc_info=True) + + def wakeup_eplb_worker(self): + self.eplb_process.planner_q.put(1) + + def forward_before(self): + if self.update_expert_weight_flag(): + (expert_send_info, expert_recv_info, updated_expert_map, + log2phy_map, layer_id) = self.update_info_all.pop(0) + log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map)) + 
self.eplb_loader.set_log2phy_map(log2phy_map_this_rank) + updated_expert_map_this_rank = torch.from_numpy( + numpy.array(updated_expert_map)) + self.eplb_loader.generate_expert_d2d_transfer_task( + expert_send_info, expert_recv_info, + updated_expert_map_this_rank, + layer_id + self.adaptor.num_dense_layers) + + # set asynchronous stream for d2d expert weight update + self.reqs = [] + self.eplb_loader.asyn_expert_weight_transfer(self.reqs) + + def take_update_info_from_eplb_process(self): + # Batch after eplb process being triggered, get update info provided by eplb process + if self.get_update_info_flag(): + self.update_info_all = self.eplb_process.block_update_q.get() + + def forward_end(self): + if self.wakeup_eplb_worker_flag(): + self.compute_and_set_moe_load(is_clear=True) + self.wakeup_eplb_worker() + + if self.update_expert_weight_flag( + ) and self.expert_map_record_path is None: + self.eplb_loader.update_expert_map_and_weight(self.reqs) + + self.update_iteration() + + def compute_and_set_moe_load(self, is_clear=False): + local_load = self.adaptor.get_rank_expert_workload() + + self._gather_buffer = None + if dist.is_initialized(): + self.world_size = dist.get_world_size() + self.device = local_load.device + if self._gather_buffer is None: + shape = (self.world_size, *local_load.shape) + self._gather_buffer = torch.empty(shape, + dtype=local_load.dtype, + device=self.device) + + dist.all_gather_into_tensor(self._gather_buffer, local_load) + + moe_load = self._gather_buffer.permute(1, 0, 2) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug( + f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" + ) + else: + moe_load = local_load.unsqueeze(1) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug( + f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}" + ) + return moe_load + + def warm_up_eplb(self): + + self.get_init_expert_map() + self.compute_and_set_moe_load() + + src_tensor = torch.empty((1, ), device=self.device) + self_rank = dist.get_rank() + + comm_op_list = [] + + for dst_rank in range(self.world_size): + if dst_rank == self_rank: + continue + comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank)) + + for src_rank in range(self.world_size): + if src_rank == self_rank: + continue + comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank)) + if comm_op_list: + reqs = dist.batch_isend_irecv(comm_op_list) + + for req in reqs: + req.wait() + + def shutdown(self): + """ + Clean up the EPLB process. + """ + if self.process.is_alive(): + self.process.terminate() + self.process.join() + logger.info("[ModelRunner] EPLB process terminated") diff --git a/vllm_npu/eplb/utils.py b/vllm_npu/eplb/utils.py new file mode 100644 index 0000000..61e5735 --- /dev/null +++ b/vllm_npu/eplb/utils.py @@ -0,0 +1,77 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
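+#
+# Usage note (an illustration added for clarity, not part of the original
+# change; the call site below is hypothetical): model_register() attaches the
+# EPLB helper methods to an already-constructed model, roughly
+#
+#     model_register(model, vllm_config.model_config)
+#     loads = model.get_all_moe_loads()  # stacked per-MoE-layer load tensors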
+# +# Todo: Once https://github.com/vllm-project/vllm/pull/23553 is merged in vllm. Remove this model register. +import types + +import torch + + +def get_expert_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_map() + + +def get_log2phy_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_log2phy_map() + + +def get_all_expert_map(self, num_moe_layers): + all_loads = [] + num_dense_layers = self.num_dense_layers if hasattr( + self, "num_dense_layers") else 0 + for layer_id in range(num_moe_layers): + load_tensor = self.get_expert_map( + layer_id + num_dense_layers) # (num_experts_per_layer,) + all_loads.append(load_tensor) + + return torch.stack(all_loads, dim=0) + + +def get_all_moe_loads(self): + num_dense_layers = self.num_dense_layers if hasattr( + self, "num_dense_layers") else 0 + all_moe_loads = torch.stack( + [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \ + for layer_id in range(self.num_moe_layers)], + dim=0 + ) + return all_moe_loads + + +def clear_all_moe_loads(self): + num_dense_layers = self.num_dense_layers if hasattr( + self, "num_dense_layers") else 0 + for layer_id in range(self.num_moe_layers): + self.model.layers[layer_id + + num_dense_layers].mlp.experts.clear_moe_load() + + +def model_register(model, model_config): + model.get_expert_map = types.MethodType(get_expert_map, model) + model.get_log2phy_map = types.MethodType(get_log2phy_map, model) + model.get_all_expert_map = types.MethodType(get_all_expert_map, model) + model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model) + model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model) + + config = model_config.hf_config + + if config.model_type == "qwen3_moe": + model.num_moe_layers = config.num_hidden_layers + elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3": + model.num_dense_layers = config.first_k_dense_replace + model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers + else: + raise NotImplementedError("EPLB is not supported.") diff --git a/vllm_npu/lora/__init__.py b/vllm_npu/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/lora/lora_ops.py b/vllm_npu/lora/lora_ops.py new file mode 100644 index 0000000..58d0ea6 --- /dev/null +++ b/vllm_npu/lora/lora_ops.py @@ -0,0 +1,113 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
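+#
+# These wrappers keep the argument order of vllm.lora.ops.torch_ops but
+# dispatch to the Ascend custom kernels registered under torch.ops._C_ascend.
+# Illustrative call (tensor shapes are an assumption, not taken from this
+# patch): with x of shape [num_tokens, hidden_size] and y a per-token
+# LoRA-rank buffer,
+#
+#     bgmv_shrink(x, lora_a_stacked, y, token_lora_indices, scaling=1.0)
+#
+# computes each token's down-projection through its active adapter, scaled
+# by `scaling`.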
+ +import torch + + +def bgmv_shrink(inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + return torch.ops._C_ascend.bgmv_shrink( + inputs, + lora_a_weights, + lora_indices_tensor, + output_tensor, + scaling, + ) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + return torch.ops._C_ascend.bgmv_expand( + inputs, + lora_b_weights, + lora_indices_tensor, + output_tensor, + 0, + output_tensor.size(1), + ) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, output_tensor, + slice_offset, slice_size) + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, scaling) + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + return torch.ops._C_ascend.sgmv_expand( + inputs, + lora_b_weights, + lora_indices_tensor, + seq_len_tensor, + output_tensor, + 0, + output_tensor.size(1), + ) + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, slice_offset, + slice_size) diff --git a/vllm_npu/lora/punica_npu.py b/vllm_npu/lora/punica_npu.py new file mode 100644 index 0000000..ac07de8 --- /dev/null +++ b/vllm_npu/lora/punica_npu.py @@ -0,0 +1,356 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Optional, Tuple, Union + +import torch + +from vllm_npu.utils import is_310p + +if is_310p(): + from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +else: + from vllm_npu.lora.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase + +from vllm_npu.lora.utils import refresh_all_lora_classes + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperNPU(PunicaWrapperBase): + """ + PunicaWrapperNPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. 
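+
+    A rough usage sketch (an illustration under assumed arguments, not taken
+    from this change):
+
+        wrapper = PunicaWrapperNPU(max_num_batched_tokens, max_batches, "npu")
+        wrapper.add_lora_linear(y, x, lora_a_stacked, lora_b_stacked,
+                                None, scale, output_slices)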
+ """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + refresh_all_lora_classes() + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. 
+ + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. 
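+
+        For example, for a merged QKV projection there are three slices:
+        ``lora_a_stacked`` and ``lora_b_stacked`` each hold three tensors and
+        ``output_slices`` holds the corresponding (per-partition) q, k and v
+        output sizes, laid out back to back in ``y``.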
+ """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + + if buffer is None: + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + indices = self.sampler_indices + + bgmv_shrink(x, lora_a_stacked, buffer, indices, scale) + bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True) + + y = y.view_as(y_org) diff --git a/vllm_npu/lora/utils.py b/vllm_npu/lora/utils.py new file mode 100644 index 0000000..f57b201 --- /dev/null +++ b/vllm_npu/lora/utils.py @@ -0,0 +1,110 @@ +from typing import Optional + +import vllm +from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +from vllm.lora.layers.utils import _not_fully_sharded_can_replace + +from vllm_npu.ops.linear import (AscendColumnParallelLinear, + AscendMergedColumnParallelLinear, + AscendQKVParallelLinear, + AscendRowParallelLinear) +from vllm_npu.ops.vocab_parallel_embedding import \ + AscendVocabParallelEmbedding + + +class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendColumnParallelLinear + + +class AscendMergedColumnParallelLinearWithLoRA( + MergedColumnParallelLinearWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendMergedColumnParallelLinear + + +class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendRowParallelLinear + + +class 
AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA): + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is AscendVocabParallelEmbedding + + +class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA): + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is AscendQKVParallelLinear and len( + packed_modules_list) == 1 + + +class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA): + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is AscendQKVParallelLinear + and len(packed_modules_list) == 3) + + +def refresh_all_lora_classes(): + vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add( + AscendMergedColumnParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add( + AscendMergedQKVParallelLinearWithLoRA) diff --git a/vllm_npu/meta_registration.py b/vllm_npu/meta_registration.py new file mode 100644 index 0000000..3c57086 --- /dev/null +++ b/vllm_npu/meta_registration.py @@ -0,0 +1,105 @@ +import torch +from torch.library import Library + +# This file provides a template and registration utilities for writing "meta" implementations +# of custom operators in Python for the vllm_npu project. +# +# We offer two ways to implement meta implementations for custom ops: +# 1. Python meta implementation (as shown in this file): Write a Python function that +# takes the same arguments as your operator and returns empty tensors with the correct +# shapes and dtypes. This is useful for rapid prototyping and for ops that are only +# used in Python. +# 2. C++ meta implementation: You can also implement the meta function in C++ for better +# performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp` +# for examples of C++ meta implementations and how to register them. +# +# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which +# is essential for supporting `torch.compile` and aclgraph. + +# How to add a new meta implementation in Python: +# ------------------------------------- +# 1. Write a Python function that takes the same arguments as your operator, and returns +# empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes. +# Do NOT perform any real computation or allocate device memory. +# +# 2. Register your meta function using `register_meta_if_necessary`, providing: +# - The namespace (usually "_C_ascend" for custom ops) +# - The operator name (as registered in C++) +# - The Python meta function +# - (Optional) The overload name, if your op has overloads +# +# 3. The registration utility will check if a meta implementation already exists for your op, +# and only register if necessary. This avoids duplicate registrations. +# +# 4. 
Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask. +# +# 5. When developing new custom ops, always provide a meta implementation to enable tracing, +# export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile` +# and aclgraph. +# +# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors + +lib = Library("_C_ascend", "IMPL") + + +def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""): + if overload != "": + op_name = op_name + "." + overload + schema_to_find = ns + "::" + op_name + meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key( + "Meta") + if schema_to_find in meta_impl_list: + return + lib.impl(op_name, fn, "Meta") + + +def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool): + + num_tokens = positions.numel() + query_hidden_size = query.numel() // num_tokens + key_hidden_size = key.numel() // num_tokens + num_heads = query_hidden_size // head_size + num_kv_heads = key_hidden_size // head_size + + query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size) + key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size) + return query_dst, key_dst + + +def get_masked_input_and_mask_meta(input: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int): + + masked_input = torch.empty_like(input) + mask = torch.empty_like(input).to(torch.bool) + + return masked_input, mask + + +def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, + indices: torch.Tensor, y: torch.Tensor, slice_offset: int, + slice_size: int): + + y_out = torch.empty_like(y) + return y_out + + +def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, + lora_indices: torch.Tensor, seq_len: torch.Tensor, + y: torch.Tensor, slice_offset: int, slice_size: int): + + y_out = torch.empty_like(y) + return y_out + + +register_meta_if_necessary("_C_ascend", "rotary_embedding", + rotary_embedding_meta) +register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", + get_masked_input_and_mask_meta) +register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta) +register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta) diff --git a/vllm_npu/models/__init__.py b/vllm_npu/models/__init__.py new file mode 100644 index 0000000..447351a --- /dev/null +++ b/vllm_npu/models/__init__.py @@ -0,0 +1,48 @@ +from vllm import ModelRegistry + +import vllm_npu.envs as envs_ascend + + +def register_model(): + ModelRegistry.register_model( + "Qwen2VLForConditionalGeneration", + "vllm_npu.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") + + ModelRegistry.register_model( + "Qwen3VLMoeForConditionalGeneration", + "vllm_npu.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration" + ) + + ModelRegistry.register_model( + "Qwen3VLForConditionalGeneration", + "vllm_npu.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration" + ) + + if envs_ascend.USE_OPTIMIZED_MODEL: + ModelRegistry.register_model( + "Qwen2_5_VLForConditionalGeneration", + "vllm_npu.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration" + ) + ModelRegistry.register_model( + "Qwen2_5OmniModel", + "vllm_npu.models.qwen2_5_omni_thinker:AscendQwen2_5OmniThinkerForConditionalGeneration" + ) + else: + 
ModelRegistry.register_model( + "Qwen2_5_VLForConditionalGeneration", + "vllm_npu.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" + ) + + ModelRegistry.register_model( + "DeepseekV32ForCausalLM", + "vllm_npu.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM") + + # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization + # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. + ModelRegistry.register_model( + "PanguProMoEForCausalLM", + "vllm_npu.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" + ) + ModelRegistry.register_model( + "Qwen3NextForCausalLM", + "vllm_npu.models.qwen3_next:CustomQwen3NextForCausalLM") diff --git a/vllm_npu/models/deepseek_v3_2.py b/vllm_npu/models/deepseek_v3_2.py new file mode 100644 index 0000000..bca5da5 --- /dev/null +++ b/vllm_npu/models/deepseek_v3_2.py @@ -0,0 +1,633 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
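+#
+# Ascend implementation of DeepSeek V3.2. The decoder layers swap vLLM's MLA
+# attention for the sparse flash attention (SFA) wrapper defined in
+# vllm_npu.models.layers.sfa when model_config.use_mla is enabled.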
+# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from typing import Any, Dict, Iterable, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (divide, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, split_tensor_along_last_dim, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # noqa: E501 +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, + DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, + get_spec_layer_idx_from_weight_name) +from vllm.model_executor.models.utils import ( + PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.models.layers.sfa import (AscendSFAModules, + AscendSparseFlashAttention, Indexer) +from vllm_npu.ops.common_fused_moe import AscendFusedMoE +from vllm_npu.ops.linear import AscendLinearBase + + +@support_torch_compile +class AscendDeepseekV2Model(DeepseekV2Model, nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Rewrite this init func mainly for removing cuda-hard code + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + assert hasattr(config, "index_topk") + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device=current_platform.device_type) + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + 
else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + +class CustomDeepseekV2RowParallelLinear(RowParallelLinear): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + # Divide the weight matrix along the first dimension. + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + AscendLinearBase.__init__(self, + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = nn.Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + self.update_param_tp_status() + + def forward( + self, + input_, + is_prefill=True, + is_force_scatter=False + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. 
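+        # quant_method.apply performs the (possibly quantized) GEMM on this
+        # rank's input shard; when reduce_results is set and TP > 1 the
+        # partial outputs are summed across ranks via the all-reduce below.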
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + return_bias=False, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + return_bias=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + return_bias=False, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + return_bias=False, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + return_bias=False, + ) + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + 
rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.dim: int = config.hidden_size # 7168 + # TODO(zzzzwwjj): wait transformers add these params + self.n_heads: int = 64 # 64 + self.head_dim: int = 128 # 128 + self.index_topk: int = 2048 # 2048 + self.indexer = Indexer( + config, + quant_config=quant_config, + dim=self.dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + index_topk=self.index_topk, + prefix=f"{prefix}.indexer", + ) + + sfa_modules = AscendSFAModules( + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + rotary_emb=self.rotary_emb, + indexer=self.indexer) + + self.sfa_attn = AscendSparseFlashAttention( + self.hidden_size, + self.enable_shared_expert_dp, + self.debug_layer_idx, + self.first_k_dense_replace, + self.tp_size, + sfa_modules, + self.num_local_heads, + self.scaling, + self.layers, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.q_lora_rank, + self.qk_nope_head_dim, + self.qk_head_dim, + self.v_head_dim, + cache_config, + quant_config, + prefix, + ) + self.prefix = prefix + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata) + + +class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): + + def __init__(self, + vllm_config: VllmConfig, + prefix: str, + topk_indices_buffer=None) -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. 
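+        # e.g. a prefix of "model.layers.3" yields layer_idx == 3 below.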
+ layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + self.layers = config.num_hidden_layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + # TODO: enable mla in vllm-ascend + if model_config.use_mla: + attn_cls = CustomDeepseekV2SFAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + if self.mlp.gate.e_score_correction_bias is not None: + self.mlp.gate.e_score_correction_bias.data = ( + self.mlp.gate.e_score_correction_bias.data.to( + dtype=torch.get_default_dtype())) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + self.first_k_dense_replace = config.first_k_dense_replace + self.tp_group = get_tp_group().device_group + + +class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. 
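+        # Illustrative effect (when q_lora_rank is set): checkpoint weights
+        # for "q_a_proj" and "kv_a_proj_with_mqa" are loaded into the single
+        # packed "fused_qkv_a_proj" parameter, so quantization and LoRA
+        # handling see one fused module instead of two.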
+ self.fuse_qkv_a_proj = hasattr( + config, "q_lora_rank") and config.q_lora_rank is not None + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = AscendDeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights: list[Any] = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + # Pick last one layer since the first ones may be dense layers. + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + # NOTE: This `load_weights` is mainly copied from + # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 + # to fix CI, and it is different from the implementation in main + # TODO: support eplb style load_weights + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + """""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = AscendFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "module" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." 
in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + return_success=False) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): + pass + + +DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__ diff --git a/vllm_npu/models/layers/__init__.py b/vllm_npu/models/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/models/layers/mla.py b/vllm_npu/models/layers/mla.py new file mode 100644 index 0000000..b50422e --- /dev/null +++ b/vllm_npu/models/layers/mla.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
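+#
+# Ascend wrapper around vLLM's multi-head latent attention (MLA). The forward
+# pass is routed through the custom op ``torch.ops.vllm.mla_forward``
+# (registered at the bottom of this file via ``direct_register_custom_op``),
+# which writes its result into the preallocated ``output`` buffer in place so
+# that graph capture can treat the attention call as a single opaque op.
+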
+from typing import Optional + +import torch +from torch import nn +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.mla import MLAModules +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils import direct_register_custom_op + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.attention import Attention + from vllm.model_executor.layers.mla import \ + MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper +else: + from vllm.attention.layer import MLAAttention + from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper + + +# TODO(whx): adapt v0.11.0 and DSA +class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): + + def __init__( + self, + hidden_size: int, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + mla_modules: MLAModules, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.q_lora_rank = q_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.prefix = prefix + hf_config = get_current_vllm_config().model_config.hf_config + self.enable_shared_expert_dp = get_ascend_config( + ).enable_shared_expert_dp + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + self.first_k_dense_replace = hf_config.first_k_dense_replace + self.tp_size = get_tensor_model_parallel_world_size() + self.layers = hf_config.num_hidden_layers + + if vllm_version_is("0.11.0"): + self.mla_attn = Attention( + num_heads=num_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scale, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + qk_head_dim=self.qk_head_dim, + rotary_emb=mla_modules.rotary_emb, + fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, + q_b_proj=mla_modules.q_b_proj, + q_a_layernorm=mla_modules.q_a_layernorm, + q_proj=mla_modules.q_proj, + kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, + kv_a_layernorm=mla_modules.kv_a_layernorm, + kv_b_proj=mla_modules.kv_b_proj, + o_proj=mla_modules.o_proj, + ) + else: + self.mla_attn = MLAAttention( + num_heads=self.num_heads, + scale=scale, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + kv_b_proj=mla_modules.kv_b_proj, + use_sparse=mla_modules.is_sparse, + indexer=mla_modules.indexer, + # extra args + qk_head_dim=self.qk_head_dim, + rotary_emb=mla_modules.rotary_emb, + 
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, + q_b_proj=mla_modules.q_b_proj, + q_a_layernorm=mla_modules.q_a_layernorm, + q_proj=mla_modules.q_proj, + kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, + kv_a_layernorm=mla_modules.kv_a_layernorm, + o_proj=mla_modules.o_proj, + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + need_gather_q_kv = get_forward_context().sp_enabled + output_shape = hidden_states.shape + # FIXME: This does not seem right, should make sure the buffer is fixed + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, + self.prefix) + output = output.view(-1, output_shape[-1]) + return output + + +def mla_forward( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if forward_context.attn_metadata: + attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name] + else: + attn_metadata = forward_context.attn_metadata + kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] + self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states, + kv_cache, attn_metadata, need_gather_q_kv, + output) + return + + +def mla_forward_fake( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="mla_forward", + op_func=mla_forward, + mutates_args=["output"], + fake_impl=mla_forward_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_npu/models/layers/sfa.py b/vllm_npu/models/layers/sfa.py new file mode 100644 index 0000000..7c8fb27 --- /dev/null +++ b/vllm_npu/models/layers/sfa.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
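+#
+# Sparse flash attention (SFA) layer used by DeepSeek V3.2 on Ascend. Like the
+# MLA wrapper in mla.py, the heavy lifting is delegated to the custom op
+# ``torch.ops.vllm.sfa_forward`` registered at the bottom of this file, which
+# mutates the preallocated ``output`` buffer in place.
+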
+from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.mla import MultiHeadLatentAttention +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils import direct_register_custom_op + + +@dataclass +class AscendSFAModules: + q_a_proj: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: torch.nn.Module + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + o_proj: torch.nn.Module + rotary_emb: torch.nn.Module + indexer: torch.nn.Module + + +class AscendSparseFlashAttention(MultiHeadLatentAttention): + + def __init__( + self, + hidden_size: int, + enable_shared_expert_dp: bool, + debug_layer_idx: int, + first_k_dense_replace: int, + tp_size: int, + sfa_modules: AscendSFAModules, + num_local_heads: int, + scaling: float, + layers: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + q_lora_rank: Optional[int], + qk_nope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.enable_shared_expert_dp = enable_shared_expert_dp + self.debug_layer_idx = debug_layer_idx + self.first_k_dense_replace = first_k_dense_replace + self.tp_size = tp_size + self.num_local_heads = num_local_heads + self.layers = layers + self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.q_lora_rank = q_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.prefix = prefix + + self.sfa_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + use_sparse=True, + # SFA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=sfa_modules.rotary_emb, + q_a_proj=sfa_modules.q_a_proj, + q_a_layernorm=sfa_modules.q_a_layernorm, + q_proj=sfa_modules.q_proj, + kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa, + kv_a_layernorm=sfa_modules.kv_a_layernorm, + kv_b_proj=sfa_modules.kv_b_proj, + o_proj=sfa_modules.o_proj, + indexer=sfa_modules.indexer) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # Simulate all gather to calculate output shape + num_tokens = num_tokens * 
self.tp_size + need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: + output_shape = hidden_states.shape + else: + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + # FIXME: This does not seem right, should make sure the buffer is fixed + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output, + self.prefix) + output = output.view(-1, output_shape[-1]) + return output + + +def sfa_forward( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if forward_context.attn_metadata: + attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name] + else: + attn_metadata = forward_context.attn_metadata + kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine] + self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, + need_gather_q_kv, output) + return + + +class Indexer(nn.Module): + + def __init__(self, + config, + dim: int = 7168, + n_heads: int = 64, + head_dim: int = 128, + index_topk: int = 2048, + q_lora_rank: int = 1536, + rope_head_dim: int = 64, + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = ""): + super().__init__() + + self.dim: int = dim # 7168 + self.n_heads: int = n_heads # 64 + self.head_dim: int = head_dim # 128 + self.rope_head_dim: int = rope_head_dim # 64 + self.index_topk: int = index_topk # 2048 + self.q_lora_rank: int = q_lora_rank # 1536 + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + return_bias=False, + ) + self.wk = ReplicatedLinear( + self.dim, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + return_bias=False, + ) + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.weights_proj", + return_bias=False, + ) + self.k_norm = nn.LayerNorm(self.head_dim) + self.softmax_scale = self.head_dim**-0.5 + + def forward(self): + return + + +def sfa_forward_fake( + hidden_states: torch.Tensor, + need_gather_q_kv: bool, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="sfa_forward", + op_func=sfa_forward, + mutates_args=["output"], + fake_impl=sfa_forward_fake, + dispatch_key="PrivateUse1", +) diff --git a/vllm_npu/models/qwen2_5_omni_thinker.py b/vllm_npu/models/qwen2_5_omni_thinker.py new file mode 100644 index 0000000..b705a57 --- /dev/null +++ b/vllm_npu/models/qwen2_5_omni_thinker.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Adapted from vllm/model_executor/models/qwen2_5_vl.py +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import \
+    Qwen2_5OmniThinkerConfig
+from vllm.config import VllmConfig
+from vllm.model_executor.models.qwen2_5_omni_thinker import (
+    Qwen2_5OmniThinkerDummyInputsBuilder,
+    Qwen2_5OmniThinkerForConditionalGeneration,
+    Qwen2_5OmniThinkerMultiModalProcessor, Qwen2_5OmniThinkerProcessingInfo)
+from vllm.model_executor.models.utils import maybe_prefix
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from vllm_npu.models.qwen2_5_vl import AscendQwen2_5_VisionTransformer
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    Qwen2_5OmniThinkerMultiModalProcessor,
+    info=Qwen2_5OmniThinkerProcessingInfo,
+    dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder)
+class AscendQwen2_5OmniThinkerForConditionalGeneration(
+        Qwen2_5OmniThinkerForConditionalGeneration):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        config: Qwen2_5OmniThinkerConfig = vllm_config.model_config.hf_config.thinker_config
+        quant_config = vllm_config.quant_config
+        # The following code reuses AscendQwen2_5_VisionTransformer from
+        # Qwen2_5_VL; this introduces no difference in model structure and
+        # does not get in the way of removing the modeling files later.
+        self.visual = AscendQwen2_5_VisionTransformer(
+            vision_config=config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "visual"),
+        )
diff --git a/vllm_npu/models/qwen2_5_vl.py b/vllm_npu/models/qwen2_5_vl.py
new file mode 100644
index 0000000..118ec51
--- /dev/null
+++ b/vllm_npu/models/qwen2_5_vl.py
@@ -0,0 +1,628 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Adapted from vllm/model_executor/models/qwen2_5_vl.py
+# Copyright 2023 The vLLM team.
+#
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
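+#
+# Ascend-specific Qwen2.5-VL vision tower. The main deviation from upstream is
+# that per-head dims strictly between MIN_PAD_SIZE (64) and MAX_PAD_SIZE (128)
+# are zero-padded up to 128 so the NPU flash-attention kernel can be used; the
+# QKV and output-projection weights (and their quantization scales) are padded
+# accordingly in load_weights.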
+ +from functools import partial +from typing import Callable, Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu +from einops import rearrange +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import MultiModalEmbeddings +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer, + Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.multimodal import MULTIMODAL_REGISTRY + +from vllm_npu.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz + +MIN_PAD_SIZE = 64 # min_size to pad weight +MAX_PAD_SIZE = 128 # max_size to pad weight + + +class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + embed_dim, + num_heads, + projection_size, + quant_config, + prefix, + ) + self.embed_dim = embed_dim + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head + if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: + self.hidden_size_per_attention_head = MAX_PAD_SIZE + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() + for x in (q, k, v)) + q = torch_npu.npu_rotary_mul(q, cos, sin) + k = torch_npu.npu_rotary_mul(k, cos, sin) + + q, k, v = [ + rearrange(x, "b s h d -> (b s) h d").contiguous() + for x in (q, k, v) + ] + + context_layer = torch.empty_like(q) + + # operator requires pta version >= 2.5.1 + torch_npu._npu_flash_attention_unpad( + query=q, + key=k, + value=v, + seq_len=cu_seqlens, + scale_value=self.origin_hidden_size_per_attention_head**-0.5, + num_heads=self.num_attention_heads_per_partition, + num_kv_heads=self.num_attention_heads_per_partition, + out=context_layer) + + context_layer = rearrange(context_layer, + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() + + output, _ = self.proj(context_layer) + return output + + +class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, + quant_config, prefix) + self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) + return x + + +class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__(dim, theta) + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.inv_freq = inv_freq + + +class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): + + def __init__( + self, + vision_config: Qwen2_5_VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + interleaved=False, + ) -> None: + super().__init__(vision_config, norm_eps, quant_config, prefix) + norm_layer = partial(RMSNorm, eps=norm_eps) + self.interleaved = interleaved + self.enable_pad = False + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // + 2) + self.patch_embed = AscendQwen2_5_VisionPatchEmbed( + patch_size=vision_config.patch_size, + temporal_patch_size=vision_config.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + act_fn = get_act_and_mul_fn(vision_config.hidden_act) + self.blocks = nn.ModuleList([ + AscendQwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=act_fn, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + self.hidden_size, self.num_heads) + + if 
self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: + self.enable_pad = True + self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head + self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 + self.half_pad_hidden_size_per_attention_head = ( + MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 + self.hidden_size_per_attention_head = MAX_PAD_SIZE + + def cal_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] + sin = rotary_pos_emb.sin() + if self.enable_pad: + cos = torch.nn.functional.pad( + cos, (0, self.half_pad_hidden_size_per_attention_head)) + sin = torch.nn.functional.pad( + sin, (0, self.half_pad_hidden_size_per_attention_head)) + + if not self.interleaved: + cos_new = torch.cat((cos, cos), dim=-1) + sin_new = torch.cat((sin, sin), dim=-1) + else: + cos_new = rearrange(torch.stack((cos, cos), dim=-1), + "... d two -> ...(d two)", + two=2) + sin_new = rearrange(torch.stack((sin, sin), dim=-1), + "... d two -> ...(d two)", + two=2) + cos_new = cos_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + sin_new = sin_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + return cos_new, sin_new + + def pad_qkv_bias(self, bias): + first_half = bias.reshape( + -1, 3, self.origin_hidden_size_per_attention_head + )[:, :, :self.half_origin_hidden_size_per_attention_head] + second_half = bias.reshape( + -1, 3, self.origin_hidden_size_per_attention_head + )[:, :, self.half_origin_hidden_size_per_attention_head:] + first_half_padded = torch.nn.functional.pad( + first_half, (0, self.half_pad_hidden_size_per_attention_head)) + second_half_padded = torch.nn.functional.pad( + second_half, (0, self.half_pad_hidden_size_per_attention_head)) + bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) + bias_final = bias_padded.reshape(-1) + return bias_final + + def pad_qkv_weight(self, data): + qkv_weight_first_half = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size + )[:, :, :self.half_origin_hidden_size_per_attention_head, :] + qkv_weight_second_half = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size + )[:, :, self.half_origin_hidden_size_per_attention_head:, :] + + qkv_weight_first_half_padded = torch.nn.functional.pad( + qkv_weight_first_half, + (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) + qkv_weight_second_half_padded = torch.nn.functional.pad( + qkv_weight_second_half, + (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) + qkv_weight_padded = torch.cat( + [qkv_weight_first_half_padded, qkv_weight_second_half_padded], + dim=2) + qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) + + if is_enable_nz(qkv_weight_final.dtype): + qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( + qkv_weight_final) + qkv_weight_final_copy = torch_npu.npu_format_cast( + qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) + return qkv_weight_final_copy + + return qkv_weight_final + + def pad_proj_weight(self, data): + out_weight = torch.nn.functional.pad( + data.reshape(self.hidden_size, -1, + self.half_origin_hidden_size_per_attention_head), + (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( + self.hidden_size, -1) + + if is_enable_nz(out_weight.dtype): + out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) + out_weight_copy = torch_npu.npu_format_cast( + out_weight_copy, 
ACL_FORMAT_FRACTAL_ND) + return out_weight_copy + + return out_weight + + def pad_qkv_weight_scale_offset(self, data): + reshaped_data = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head, 1) + data1 = reshaped_data[:, :, :self. + half_origin_hidden_size_per_attention_head, :] + data2 = reshaped_data[:, :, self. + half_origin_hidden_size_per_attention_head:, :] + data1_paded = torch.nn.functional.pad( + data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, + 0, 0, 0)) + data2_paded = torch.nn.functional.pad( + data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, + 0, 0, 0)) + res = torch.cat([data1_paded, data2_paded], dim=2) + res = res.reshape(-1, 1) + return res + + def pad_qkv_deq_scale_quant_bias(self, data): + reshaped_data = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head) + data1 = reshaped_data[:, :, :self. + half_origin_hidden_size_per_attention_head] + data2 = reshaped_data[:, :, + self.half_origin_hidden_size_per_attention_head:] + + data1_paded = torch.nn.functional.pad( + data1, (0, self.half_pad_hidden_size_per_attention_head)) + data2_paded = torch.nn.functional.pad( + data2, (0, self.half_pad_hidden_size_per_attention_head)) + + res = torch.cat([data1_paded, data2_paded], dim=2) + res = res.reshape(-1) + return res + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ("mlp.gate_up_proj.", "mlp.gate_proj.", 0), + ("mlp.gate_up_proj.", "mlp.up_proj.", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + if self.enable_pad and shard_id == "v": + if "attn.qkv.weight" in name: + param.data = self.pad_qkv_weight(param.data) + if "attn.qkv.bias" in name: + param.data = self.pad_qkv_bias(param.data) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if ("attn.proj.weight_scale" in name or + "attn.proj.weight_offset" in name) and self.enable_pad: + continue + elif ("attn.proj.deq_scale" in name + or "attn.proj.quant_bias" in name) and self.enable_pad: + continue + elif ("attn.qkv.weight_scale" in name + or "attn.qkv.weight_offset" in name) and self.enable_pad: + param.data = self.pad_qkv_weight_scale_offset(param.data) + elif ("attn.qkv.deq_scale" in name + or "attn.qkv.quant_bias" in name) and self.enable_pad: + param.data = self.pad_qkv_deq_scale_quant_bias(param.data) + elif ("attn.proj.weight" in name) and self.enable_pad: + param.data = self.pad_proj_weight(param.data) + elif ("attn.qkv.weight" in name) and self.enable_pad: + param.data = self.pad_qkv_weight(param.data) + elif ("attn.qkv.bias" in name) and self.enable_pad: + param.data = self.pad_qkv_bias(param.data) + loaded_params.add(name) + return loaded_params + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: 
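+            # Build per-patch (h, w) position ids, then reorder them so that
+            # patches falling in the same spatial-merge window
+            # (spatial_merge_size x spatial_merge_size) become contiguous;
+            # the per-frame ids are repeated t times along the temporal axis.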
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = (self.window_size // + self.spatial_merge_size // self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, + 0]).cpu().to(torch.int32) + + # patchify + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=x.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) + seq_len, _ = x.size() + x = x.reshape(seq_len // self.spatial_merge_unit, + self.spatial_merge_unit, -1) + x = x[window_index, :, :] + x = x.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cos, sin = self.cal_cos_sin(rotary_pos_emb) + + # transformers + x = x.unsqueeze(1) + for 
layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) + + # adapter + x = self.merger(x) + reverse_indices = torch.argsort(window_index) + x = x[reverse_indices, :] + return x + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class AscendQwen2_5_VLForConditionalGeneration( + Qwen2_5_VLForConditionalGeneration): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.visual = AscendQwen2_5_VisionTransformer( + vision_config=config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + + def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + return image_embeds.split(sizes.tolist()) + + def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + return video_embeds.split(sizes.tolist()) + + def _get_text_embeddings( + self, + input_ids: torch.Tensor, + get_input_embeddings: Callable[[torch.Tensor], torch.Tensor], + *, + is_multimodal: Optional[torch.Tensor], + handle_oov_mm_token: bool, + ) -> torch.Tensor: + if handle_oov_mm_token and is_multimodal is not None: + is_text = ~is_multimodal + text_embeds = get_input_embeddings(input_ids[is_text]) + + return torch.empty( + (input_ids.shape[0], text_embeds.shape[1]), + dtype=text_embeds.dtype, + device=text_embeds.device, + ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds) + + return get_input_embeddings(input_ids) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + """ + Apply token embeddings to `input_ids`. + + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. + + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=False` + to avoid calling the language model's `get_input_embeddings` method + on those tokens. 
Note however that doing so increases memory usage + as an additional buffer is needed to hold the input embeddings. + """ + from vllm.model_executor.models.utils import \ + _merge_multimodal_embeddings + + inputs_embeds = self._get_text_embeddings( + input_ids, + self.get_language_model().get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + is_multimodal=is_multimodal, + multimodal_embeddings=multimodal_embeddings, + ) diff --git a/vllm_npu/models/qwen2_5_vl_without_padding.py b/vllm_npu/models/qwen2_5_vl_without_padding.py new file mode 100644 index 0000000..4458260 --- /dev/null +++ b/vllm_npu/models/qwen2_5_vl_without_padding.py @@ -0,0 +1,780 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
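+#
+# Ascend variants of the Qwen2.5-VL (and, when available, Qwen3-VL /
+# Qwen3-VL-MoE) vision towers that keep the original head_dim and therefore
+# skip the pad-to-128 weight handling used in vllm_npu/models/qwen2_5_vl.py.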
+ +from functools import partial +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu +from einops import rearrange +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) +from vllm.model_executor.models.interfaces import MultiModalEmbeddings + +try: + from transformers.models.qwen3_vl.configuration_qwen3_vl import \ + Qwen3VLConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \ + Qwen3VLMoeConfig +except ImportError: + pass +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, + get_act_and_mul_fn) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder, + Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor, + Qwen2_5_VLProcessingInfo) + +try: + from vllm.model_executor.models.qwen3_vl import ( + Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer, + Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) + from vllm.model_executor.models.qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo) +except ImportError: + Qwen3_VisionBlock = object + Qwen3_VisionPatchEmbed = object + Qwen3_VisionTransformer = object + Qwen3VLDummyInputsBuilder = object + Qwen3VLForConditionalGeneration = object + Qwen3VLMultiModalProcessor = object + Qwen3VLProcessingInfo = object + Qwen3VLMoeForConditionalGeneration = object + Qwen3VLMoeProcessingInfo = object +from vllm.model_executor.models.utils import (WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix) +from vllm.multimodal import MULTIMODAL_REGISTRY + +from vllm_npu.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding + + +class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + embed_dim, + num_heads, + projection_size, + quant_config, + prefix, + ) + self.embed_dim = embed_dim + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() + for x in (q, k, v)) + q = torch_npu.npu_rotary_mul(q, cos, sin) + k = torch_npu.npu_rotary_mul(k, cos, sin) + + q, k, v = [ + rearrange(x, "b s h d -> (b s) h d").contiguous() + for x in (q, k, v) + ] + + context_layer = torch.empty_like(q) + + # operator requires pta version >= 2.5.1.dev20250226 + torch_npu._npu_flash_attention_unpad( + query=q, + key=k, + value=v, + seq_len=cu_seqlens, + scale_value=self.hidden_size_per_attention_head**-0.5, + num_heads=self.num_attention_heads_per_partition, + num_kv_heads=self.num_attention_heads_per_partition, + out=context_layer) + + context_layer = rearrange(context_layer, + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() + + output, _ = self.proj(context_layer) + return output + + +class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock): + + def __init__(self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, + quant_config, prefix) + self.attn = AscendQwen2_5_VisionAttention_Without_Padding( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) + return x + + +class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer + ): + + def __init__( + self, + vision_config: Qwen2_5_VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + interleaved=False, + ) -> None: + super().__init__(vision_config, norm_eps, quant_config, prefix) + norm_layer = partial(RMSNorm, eps=norm_eps) + self.interleaved = interleaved + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // + 2) + self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding( + patch_size=vision_config.patch_size, + temporal_patch_size=vision_config.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + act_fn = get_act_and_mul_fn(vision_config.hidden_act) + self.blocks = nn.ModuleList([ + AscendQwen2_5_VisionBlock_Without_Padding( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=act_fn, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + self.hidden_size, self.num_heads) + + def cal_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] + sin = rotary_pos_emb.sin() + + if not self.interleaved: + cos_new = torch.cat((cos, cos), dim=-1) + sin_new = torch.cat((sin, sin), dim=-1) + else: 
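+            # Interleaved rotary layout: duplicate each cos/sin entry in place
+            # ([c0, c0, c1, c1, ...]) rather than concatenating two half-copies
+            # as in the non-interleaved branch above.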
+ cos_new = rearrange(torch.stack((cos, cos), dim=-1), + "... d two -> ...(d two)", + two=2) + sin_new = rearrange(torch.stack((sin, sin), dim=-1), + "... d two -> ...(d two)", + two=2) + cos_new = cos_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + sin_new = sin_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + return cos_new, sin_new + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = (self.window_size // + self.spatial_merge_size // self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, + 0]).cpu().to(torch.int32) + + # patchify + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=x.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) + 
seq_len, _ = x.size() + x = x.reshape(seq_len // self.spatial_merge_unit, + self.spatial_merge_unit, -1) + x = x[window_index, :, :] + x = x.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cos, sin = self.cal_cos_sin(rotary_pos_emb) + + # transformers + x = x.unsqueeze(1) + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) + + # adapter + x = self.merger(x) + reverse_indices = torch.argsort(window_index) + x = x[reverse_indices, :] + return x + + +class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) + x = x + self.proj.bias + return x + + +class AscendQwen3_VisionBlock(Qwen3_VisionBlock): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, + quant_config, prefix, use_data_parallel) + self.attn = AscendQwen2_5_VisionAttention_Without_Padding( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__(vision_config, norm_eps, quant_config, prefix, + use_data_parallel) + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + self.patch_embed = AscendQwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + self.blocks = nn.ModuleList([ + AscendQwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + self.hidden_size_per_attention_head = dist_utils.divide( + self.hidden_size, self.num_heads) + + def cal_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] + sin = rotary_pos_emb.sin() + cos_new = torch.cat((cos, cos), dim=-1) + sin_new = torch.cat((sin, sin), dim=-1) + cos_new = cos_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + sin_new = sin_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + return cos_new, sin_new + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = 
x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + grid_thw_tensor = torch.tensor(grid_thw, + device=self.device, + dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], + grid_thw_tensor[:, 0]).cpu().to(torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cos, sin = self.cal_cos_sin(rotary_pos_emb) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + cos=cos, + sin=sin) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class AscendQwen2_5_VLForConditionalGeneration_Without_Padding( + Qwen2_5_VLForConditionalGeneration): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.visual = AscendQwen2_5_VisionTransformer_Without_Padding( + vision_config=config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + + def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + return image_embeds.split(sizes.tolist()) + + def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. 
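+        # Each item yields prod(t, h, w) patches, reduced by a factor of
+        # spatial_merge_size ** 2 by the patch merger, which gives the split
+        # sizes below.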
+ merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + return video_embeds.split(sizes.tolist()) + + def _get_text_embeddings( + self, + input_ids: torch.Tensor, + get_input_embeddings: Callable[[torch.Tensor], torch.Tensor], + *, + is_multimodal: Optional[torch.Tensor], + handle_oov_mm_token: bool, + ) -> torch.Tensor: + if handle_oov_mm_token and is_multimodal is not None: + is_text = ~is_multimodal + text_embeds = get_input_embeddings(input_ids[is_text]) + + return torch.empty( + (input_ids.shape[0], text_embeds.shape[1]), + dtype=text_embeds.dtype, + device=text_embeds.device, + ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds) + + return get_input_embeddings(input_ids) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + """ + Apply token embeddings to `input_ids`. + + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. + + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=False` + to avoid calling the language model's `get_input_embeddings` method + on those tokens. Note however that doing so increases memory usage + as an additional buffer is needed to hold the input embeddings. + """ + + inputs_embeds = self._get_text_embeddings( + input_ids, + self.get_language_model().get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + is_multimodal=is_multimodal, + multimodal_embeddings=multimodal_embeddings, + ) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. 
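+    # Remap HF checkpoint prefixes: "model.visual.*" loads into `visual`,
+    # while "lm_head.*" and "model.language_model.*" load into the wrapped
+    # language model.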
+ hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen3VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLMoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class AscendQwen3VLMoeForConditionalGeneration( + Qwen3VLMoeForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + + def _get_text_embeddings( + self, + input_ids: torch.Tensor, + get_input_embeddings: Callable[[torch.Tensor], torch.Tensor], + *, + is_multimodal: Optional[torch.Tensor], + handle_oov_mm_token: bool, + ) -> torch.Tensor: + if handle_oov_mm_token and is_multimodal is not None: + is_text = ~is_multimodal + text_embeds = get_input_embeddings(input_ids[is_text]) + return torch.empty( + (input_ids.shape[0], text_embeds.shape[1]), + dtype=text_embeds.dtype, + device=text_embeds.device, + ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds) + return get_input_embeddings(input_ids) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + """ + Apply token embeddings to `input_ids`. + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=False` + to avoid calling the language model's `get_input_embeddings` method + on those tokens. Note however that doing so increases memory usage + as an additional buffer is needed to hold the input embeddings. 
+ """ + inputs_embeds = self._get_text_embeddings( + input_ids, + self.get_language_model().get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229.") + if self.use_deepstack: + ( + deepstack_input_embeds, + multimodal_embeddings, + ) = self._compute_deepstack_embeds( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + else: + deepstack_input_embeds = None + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + is_multimodal=is_multimodal, + multimodal_embeddings=multimodal_embeddings, + ) + if deepstack_input_embeds is not None: + self._set_deepstack_input_embeds(deepstack_input_embeds) + return inputs_embeds + + def _compute_deepstack_embeds( + self, + inputs_embeds: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + is_multimodal: torch.Tensor, + ) -> tuple[torch.Tensor, MultiModalEmbeddings]: + + visual_lens = [len(x) for x in multimodal_embeddings] + multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) + + total_dim = multimodal_embeddings_cat.shape[-1] + assert total_dim == self.visual_dim + self.multiscale_dim, \ + f"Total dimension mismatch: input {total_dim}, expected {self.visual_dim + self.multiscale_dim}" + multimodal_embeddings_main = multimodal_embeddings_cat[ + ..., :self.visual_dim] + multimodal_embeddings_multiscale = multimodal_embeddings_cat[ + ..., self.visual_dim:] + + multimodal_embeddings = torch.split(multimodal_embeddings_main, + visual_lens, + dim=0) + multimodal_embeddings_multiscale = torch.split( + multimodal_embeddings_multiscale, visual_lens, dim=0) + + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.size(0), + self.deepstack_num_level * inputs_embeds.size(1)) + + deepstack_input_embeds = _merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_multimodal, + ) + deepstack_input_embeds = deepstack_input_embeds.view( + inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) + deepstack_input_embeds = deepstack_input_embeds.permute( + 1, 0, 2).contiguous() + + return deepstack_input_embeds, multimodal_embeddings diff --git a/vllm_npu/models/qwen2_vl.py b/vllm_npu/models/qwen2_vl.py new file mode 100644 index 0000000..14f560e --- /dev/null +++ b/vllm_npu/models/qwen2_vl.py @@ -0,0 +1,369 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from vllm/model_executor/models/qwen2_vl.py +# This file is a part of the vllm-ascend project. 
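+#
+# Ascend variant of the Qwen2-VL vision tower: rotary embedding is applied via
+# torch_npu.npu_rotary_mul and attention runs through the unpadded
+# flash-attention kernel; head_dim values between MIN_PAD_SIZE and MAX_PAD_SIZE
+# (e.g. 80) are padded up to MAX_PAD_SIZE before that kernel is called.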
+ +from collections.abc import Iterable +from functools import partial +from typing import Callable, Optional, Set, Tuple, Type + +import torch +import torch.nn as nn +import torch_npu +from einops import rearrange +from transformers.models.qwen2_vl.configuration_qwen2_vl import \ + Qwen2VLVisionConfig +from vllm.config import VllmConfig +from vllm.distributed import utils as dist_utils +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2_vl import ( + Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed, + Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder, + Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.multimodal import MULTIMODAL_REGISTRY + +from vllm_npu.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz + +MIN_PAD_SIZE = 64 # min_size to pad weight +MAX_PAD_SIZE = 128 # max_size to pad weight + + +class AscendQwen2VisionAttention(Qwen2VisionAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + embed_dim, + num_heads, + projection_size, + quant_config, + prefix, + ) + self.cu_seqlens = None + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head + if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: + self.hidden_size_per_attention_head = MAX_PAD_SIZE + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + + self.cu_seqlens = cu_seqlens + + # [s, b, c] --> [s, b, 3 * head * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = [ + rearrange(x, "s b ... 
-> b s ...").contiguous() for x in (q, k, v) + ] + q = torch_npu.npu_rotary_mul(q, cos, sin) + k = torch_npu.npu_rotary_mul(k, cos, sin) + q, k, v = [ + rearrange(x, "b s h d -> (b s) h d").contiguous() + for x in (q, k, v) + ] + + context_layer = torch.empty_like(q) + + # operator requires pta version >= 2.5.1 + torch_npu._npu_flash_attention_unpad( + query=q, + key=k, + value=v, + seq_len=self.cu_seqlens, + scale_value=self.origin_hidden_size_per_attention_head**-0.5, + num_heads=self.num_attention_heads_per_partition, + num_kv_heads=self.num_attention_heads_per_partition, + out=context_layer) + context_layer = rearrange(context_layer, + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() + + output, _ = self.proj(context_layer) + return output + + +class AscendQwen2VisionBlock(Qwen2VisionBlock): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: Type[nn.Module] = QuickGELU, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(dim, num_heads, mlp_ratio, act_layer, norm_layer, + quant_config, prefix) + self.attn = AscendQwen2VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + cos=cos, + sin=sin, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1)) + return x + + +class AscendQwen2VisionTransformer(Qwen2VisionTransformer): + + def __init__( + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + interleaved=False, + ) -> None: + super().__init__(vision_config, norm_eps, quant_config, prefix) + + self.interleaved = interleaved + self.enable_pad = False + self.depth = vision_config.depth + self.hidden_size = vision_config.embed_dim + self.num_heads = vision_config.num_heads + self.patch_embed = AscendQwen2VisionPatchEmbed( + patch_size=vision_config.patch_size, + temporal_patch_size=vision_config.temporal_patch_size, + in_channels=vision_config.in_channels, + embed_dim=vision_config.embed_dim, + ) + + self.blocks = nn.ModuleList([ + AscendQwen2VisionBlock(dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=vision_config.mlp_ratio, + norm_layer=partial(nn.LayerNorm, + eps=norm_eps), + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + + self.hidden_size_per_attention_head = dist_utils.divide( + self.hidden_size, self.num_heads) + + if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: + self.enable_pad = True + self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head + self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 + self.half_pad_hidden_size_per_attention_head = ( + MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 + self.hidden_size_per_attention_head = MAX_PAD_SIZE + + def cal_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] + sin = 
rotary_pos_emb.sin() + if self.enable_pad: + cos = torch.nn.functional.pad( + cos, (0, self.half_pad_hidden_size_per_attention_head)) + sin = torch.nn.functional.pad( + sin, (0, self.half_pad_hidden_size_per_attention_head)) + + if not self.interleaved: + cos_new = torch.cat((cos, cos), dim=-1) + sin_new = torch.cat((sin, sin), dim=-1) + else: + cos_new = rearrange(torch.stack((cos, cos), dim=-1), + "... d two -> ...(d two)", + two=2) + sin_new = rearrange(torch.stack((sin, sin), dim=-1), + "... d two -> ...(d two)", + two=2) + cos_new = cos_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + sin_new = sin_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + return cos_new, sin_new + + def pad_qkv_bias(self, bias): + first_half = bias.reshape( + -1, 3, self.origin_hidden_size_per_attention_head + )[:, :, :self.half_origin_hidden_size_per_attention_head] + second_half = bias.reshape( + -1, 3, self.origin_hidden_size_per_attention_head + )[:, :, self.half_origin_hidden_size_per_attention_head:] + first_half_padded = torch.nn.functional.pad( + first_half, (0, self.half_pad_hidden_size_per_attention_head)) + second_half_padded = torch.nn.functional.pad( + second_half, (0, self.half_pad_hidden_size_per_attention_head)) + bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) + bias_final = bias_padded.reshape(-1) + return bias_final + + def pad_qkv_weight(self, data): + qkv_weight_first_half = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size + )[:, :, :self.half_origin_hidden_size_per_attention_head, :] + qkv_weight_second_half = data.reshape( + -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size + )[:, :, self.half_origin_hidden_size_per_attention_head:, :] + + qkv_weight_first_half_padded = torch.nn.functional.pad( + qkv_weight_first_half, + (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) + qkv_weight_second_half_padded = torch.nn.functional.pad( + qkv_weight_second_half, + (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) + qkv_weight_padded = torch.cat( + [qkv_weight_first_half_padded, qkv_weight_second_half_padded], + dim=2) + qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) + + if is_enable_nz(qkv_weight_final.dtype): + qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( + qkv_weight_final) + qkv_weight_final_copy = torch_npu.npu_format_cast( + qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) + return qkv_weight_final_copy + + return qkv_weight_final + + def pad_proj_weight(self, data): + out_weight = torch.nn.functional.pad( + data.reshape(self.hidden_size, -1, + self.half_origin_hidden_size_per_attention_head), + (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( + self.hidden_size, -1) + + if is_enable_nz(out_weight.dtype): + out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) + out_weight_copy = torch_npu.npu_format_cast( + out_weight_copy, ACL_FORMAT_FRACTAL_ND) + return out_weight_copy + + return out_weight + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = 
name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if ("attn.proj.weight" in name) and self.enable_pad: + param.data = self.pad_proj_weight(param.data) + if ("attn.qkv.weight" in name) and self.enable_pad: + param.data = self.pad_qkv_weight(param.data) + if ("attn.qkv.bias" in name) and self.enable_pad: + param.data = self.pad_qkv_bias(param.data) + loaded_params.add(name) + return loaded_params + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + # compute cu_seqlens and avoid cumsum to fit operator unpadFA + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, + 0]).cpu().to(torch.int32) + + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + cos, sin = self.cal_cos_sin(rotary_pos_emb) + + x = x.unsqueeze(1) + for blk in self.blocks: + x = blk(x, cu_seqlens=cu_seqlens, cos=cos, sin=sin) + + # adapter + x = self.merger(x) + return x + + +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder) +class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.visual = AscendQwen2VisionTransformer( + self.config.vision_config, + norm_eps=getattr(self.config, "rms_norm_eps", 1e-6), + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) diff --git a/vllm_npu/models/qwen3_next.py b/vllm_npu/models/qwen3_next.py new file mode 100644 index 0000000..47b6d3e --- /dev/null +++ b/vllm_npu/models/qwen3_next.py @@ -0,0 +1,676 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# mypy: ignore-errors +"""Inference-only Qwen3Next model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN +from vllm import envs +from vllm.attention import AttentionBackend, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, + VllmConfig, get_current_vllm_config) +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fla.ops import RMSNormGated +from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule +from vllm.model_executor.layers.fla.ops.fused_recurrent import \ + fused_recurrent_gated_delta_rule +from vllm.model_executor.layers.fused_moe import FusedMoE +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.layernorm import \ + GemmaRMSNorm as Qwen3NextRMSNorm +# yapf: enable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from 
vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_mixer2 import \ + mamba_v2_sharded_weight_loader +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.utils import set_weight_attrs +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from vllm.model_executor.models.qwen3_next import ( # isort: skip + Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM, + Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock, + fused_gdn_gating) + + +class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, + self.head_v_dim, self.conv_kernel_size, self.num_spec) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = (self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + 
input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.projection_size_qkvz, self.projection_size_ba], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + query_key_settings, + query_key_settings, + value_settings, + ], self.tp_size, self.tp_rank) + }) + + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + dtype=torch.float32, + )) + + set_weight_attrs(self.A_log, + {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + norm_before_gate=True, + device="npu", + ) + + self.out_proj = RowParallelLinear(self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj") + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_masks = attn_metadata.spec_token_masks + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + + num_actual_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens + + attn_metadata.num_spec_decode_tokens) + num_accepted_tokens = attn_metadata.num_accepted_tokens + + # 1. 
Set up dimensions for reshapes later + projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] + projected_states_qkvz, projected_states_ba = torch.split( + projected_states, + [ + self.projection_size_qkvz // self.tp_size, + self.projection_size_ba // self.tp_size + ], + dim=-1, + ) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 2.2: process the remaining part + if attn_metadata.num_prefills > 0: + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:attn_metadata + .num_decodes], + # validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( + mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec) + + beta = b.sigmoid() + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 3. Recurrent attention + # 3.1: process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[:attn_metadata. + num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...]
= 0 + + batch_size = initial_state.shape[0] + core_attn_out = [] + last_recurrent_state = [] + + for b_idx in range(batch_size): + start, end = non_spec_query_start_loc[ + b_idx], non_spec_query_start_loc[b_idx + 1] + cur_q = query_non_spec[:, start:end, ...] + cur_k = key_non_spec[:, start:end, ...] + cur_v = value_non_spec[:, start:end, ...] + cur_g = g_non_spec[:, start:end, ...] + cur_b = beta_non_spec[:, start:end, ...] + cur_state = initial_state[b_idx].unsqueeze(0) + + ( + cur_core_attn_out_non_spec, + cur_last_recurrent_state, + ) = chunk_gated_delta_rule( + query=cur_q, + key=cur_k, + value=cur_v, + g=cur_g, + beta=cur_b, + initial_state=cur_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + + core_attn_out.append(cur_core_attn_out_non_spec) + last_recurrent_state.append(cur_last_recurrent_state) + + tar_dtype = core_attn_out[0].dtype + tar_device = core_attn_out[0].device + tar_shape = list(core_attn_out[0].shape) + tar_shape[1] = non_spec_query_start_loc[-1] + core_attn_out_non_spec = torch.empty(tar_shape, + dtype=tar_dtype, + device=tar_device) + for b_idx in range(batch_size): + cur_core_attn_out = core_attn_out[b_idx] + start, end = non_spec_query_start_loc[ + b_idx], non_spec_query_start_loc[b_idx + 1] + core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out + last_recurrent_state = torch.cat(last_recurrent_state, dim=0) + + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata. + num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if (spec_sequence_masks is not None + and core_attn_out_non_spec is not None): + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, '... h d -> ... 
(h d)') + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + +class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer): + + def __init__( + self, + vllm_config: VllmConfig, + layer_type: str, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = CustomQwen3NextGatedDeltaNet( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f'{prefix}.linear_attn') + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.self_attn', + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (self.layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.torch_dtype, + ), ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.torch_dtype, + ), ) + + +@support_torch_compile +class CustomQwen3NextModel(Qwen3NextModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config: Qwen3NextConfig = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + lora_config = vllm_config.lora_config + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + return CustomQwen3NextDecoderLayer( + vllm_config, + layer_type=config.layer_types[extract_layer_index(prefix)], + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # 
(param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("in_proj", "in_proj_qkvz", 0), + ("in_proj", "in_proj_ba", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. 
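+ # Weights not captured by the stacked or expert mappings (norms, embeddings,
+ # conv1d, A_log, dt_bias, etc.) fall through to this branch: they are loaded
+ # with the weight_loader already attached to the parameter (for example the
+ # sharded loaders set in CustomQwen3NextGatedDeltaNet.__init__), falling back
+ # to default_weight_loader when none is set.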
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3Next currently does not support prefix caching" + assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" + self.quant_config = vllm_config.quant_config + self.config = config + self.scheduler_config = scheduler_config + self.model = CustomQwen3NextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3NextDecoderLayer) + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts diff --git a/vllm_npu/multistream/__init__.py b/vllm_npu/multistream/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/multistream/base.py b/vllm_npu/multistream/base.py new file mode 100644 index 0000000..fba58b4 --- /dev/null +++ b/vllm_npu/multistream/base.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from enum import Enum + + +class MSEventKey(Enum): + ATTN_COM_FINISH = 0 + ATTN_AR_FINISH = 1 + FFN_COM_FINISH = 2 + FFN_AR_FINISH = 3 + # events for MOE dispatch and combine + MOE_BEFORE_COMM = 4 + MOE_AFTER_COMM = 5 + # events for shared expert + MOE_SE_COMM_FINISH = 6 + MOE_SE_COMP_FINISH = 7 + MOE_GATE_FINISH = 8 + + +@dataclass +class MSAttentionMetadataSplitConfig: + """ + micro batch split config for split attention metadata + """ + # micro batch num + num_micro_batches: int = 2 + # split micro batches only when total tokens >= min_total_tokens_to_split + 
min_total_tokens_to_split: int = 256 + # split micro batches only when prefill tokens >= min_prefill_tokens_to_split + min_prefill_tokens_to_split: int = 64 diff --git a/vllm_npu/multistream/context.py b/vllm_npu/multistream/context.py new file mode 100644 index 0000000..a1684f2 --- /dev/null +++ b/vllm_npu/multistream/context.py @@ -0,0 +1,67 @@ +from contextlib import contextmanager +from typing import Any + +_ms_comm_context: Any = None +_cur_micro_batch_num: int = -1 +_ms_layer_index_context: int = -1 +_ms_metadata_context: Any = None +_ms_attn_metadata_context: Any = None + + +def set_multistream_layer_context(start_layer: int, ms_metadata: Any, + attn_metadata: Any): + """ + set multistream layer context before transformer layers + """ + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + _ms_layer_index_context = start_layer + _ms_metadata_context = ms_metadata + _ms_attn_metadata_context = attn_metadata + + +def reset_multistream_layer_context(): + """ + reset multistream layer context + """ + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + _ms_layer_index_context = -1 + _ms_metadata_context = None + _ms_attn_metadata_context = None + + +def get_multistream_layer_context(): + """ + get multistream layer context + """ + return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + + +def advance_step_multistream_layer_context(): + """ + advance multistream layer index context + """ + global _ms_layer_index_context + _ms_layer_index_context += 1 + + +def get_multistream_comm_context() -> Any: + """Get the current comm forward context.""" + return _ms_comm_context + + +def get_multistream_microbatch_context() -> int: + return _cur_micro_batch_num + + +@contextmanager +def set_multistream_context(context: Any, micro_batch_num: int): + """A context manager that stores the current comm forward context, + can be attention metadata, etc.""" + global _ms_comm_context, _cur_micro_batch_num + _ms_comm_context = context + _cur_micro_batch_num = micro_batch_num + try: + yield + finally: + _ms_comm_context = None + _cur_micro_batch_num = -1 diff --git a/vllm_npu/multistream/decorator.py b/vllm_npu/multistream/decorator.py new file mode 100644 index 0000000..5b573df --- /dev/null +++ b/vllm_npu/multistream/decorator.py @@ -0,0 +1,22 @@ +from .context import (get_multistream_layer_context, + get_multistream_microbatch_context) + + +# vllm v1 use get_forward_context to get the attn_metadata, +# we can use this decorator to update the attn metadata +def set_multistream_support(): + + def decorator(func): + + def wrapper(): + context = func() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + micro_batch_num = get_multistream_microbatch_context() + if layer_index != -1 and micro_batch_num != -1: + context.attn_metadata = attn_metadata[micro_batch_num] + return context + + return wrapper + + return decorator diff --git a/vllm_npu/multistream/layers.py b/vllm_npu/multistream/layers.py new file mode 100644 index 0000000..c5273bc --- /dev/null +++ b/vllm_npu/multistream/layers.py @@ -0,0 +1,61 @@ +from typing import List, Optional, Tuple, Union + +import torch +from vllm.forward_context import get_forward_context + +from .base import MSEventKey +from .context import (get_multistream_layer_context, + reset_multistream_layer_context, + set_multistream_layer_context) +from .metadata import MultiStreamMetadata + + +class MultiStreamPreTransformerLayer(torch.nn.Module): + + def __init__(self, 
multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + + def forward( + self, + intput_tensors: List[torch.Tensor], + ): + attn_metadata = get_forward_context().attn_metadata + if self.multistream_metadata is None or attn_metadata is None: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors + # TODO add attn_metadata management + do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch( + attn_metadata, intput_tensors) + if do_ms: + set_multistream_layer_context( + self.multistream_metadata.start_layer, + self.multistream_metadata, attn_metadata) + else: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors + + +class MultiStreamPostTransformerLayer(torch.nn.Module): + + def __init__(self, multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + + def forward(self, + input_tensors: Union[List[Tuple[torch.Tensor]], + List[torch.Tensor], + List[List[torch.Tensor]]], + wait_layer_index: Optional[int] = None): + if self.multistream_metadata is None or self.multistream_metadata.ms_config is None: + return input_tensors + layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context( + ) + if layer_index >= 0: + true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index + self.multistream_metadata.try_wait_event( + true_wait_layer, + self.multistream_metadata.ms_config.num_micro_batches - 1, + MSEventKey.FFN_AR_FINISH) + reset_multistream_layer_context() + return self.multistream_metadata.merge_micro_batches(input_tensors) diff --git a/vllm_npu/multistream/metadata.py b/vllm_npu/multistream/metadata.py new file mode 100644 index 0000000..27a047e --- /dev/null +++ b/vllm_npu/multistream/metadata.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +from vllm.sequence import IntermediateTensors + +from vllm_npu.attention.mla_v1 import AscendMLAMetadata + +from .base import MSAttentionMetadataSplitConfig, MSEventKey + + +def split_micro_batches_tensors(input_tensors, + split_index: int, + keys: Optional[List[str]] = None): + if isinstance(input_tensors, list): + micro_batches = [] + for tensor in input_tensors: + if tensor is None: + micro_batches.append([None, None]) + else: + micro_batches.append( + [tensor[:split_index], tensor[split_index:]]) + return micro_batches + elif isinstance(input_tensors, torch.Tensor): + return [input_tensors[:split_index], input_tensors[split_index:]] + elif input_tensors is None: + return [None, None] + elif isinstance(input_tensors, Dict): + assert keys is not None + micro_batches_pre = {} + for key in keys: + micro_batches_pre[key] = input_tensors[key][:split_index] + micro_batches_post = {} + for key in keys: + micro_batches_post[key] = input_tensors[key][split_index:] + return [micro_batches_pre, micro_batches_post] + else: + raise NotImplementedError + + +@dataclass +class MultiStreamStepMetadata: + comm_stream: torch.npu.Stream = None + before_comm_event: torch.npu.Event = None + after_comm_event: torch.npu.Event = None + + +@dataclass +class MultiStreamConfig: + """Controls the behavior of multi-stream models.""" + min_total_tokens_to_split: int = 256 + min_prefill_tokens_to_split: int = 64 + num_micro_batches: int = 2 + imbalance_ratio: float = 0.1 + + +class MultiStreamMetadata: + # direct stream + 
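+ # (the stream that runs model computation; communication work is intended to
+ # run on the delay stream below, with the two synchronized through the
+ # per-layer ms_events recorded and waited on via try_record_event /
+ # try_wait_event)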
calculate_stream = None + # delay stream + communicate_stream = None + # events + ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {} + # multi-stream-flag + enable_multi_stream: bool = False + + def __init__( + self, + calculate_stream: torch.npu.Stream, + communicate_stream: torch.npu.Stream, + start_layer: int, + end_layer: int, + event_keys: List[MSEventKey], + multistream_config: Optional[MultiStreamConfig], + causal_lm: bool = True, + ): + self.calculate_stream = calculate_stream + self.communicate_stream = communicate_stream + self.start_layer = start_layer + self.end_layer = end_layer + self.ms_config = multistream_config + self.causal_lm = causal_lm + self._build_events(event_keys) + self._build_ms_split_config() + + def _build_events(self, event_keys): + if self.ms_config is not None: + for i in range(self.start_layer - 1, self.end_layer): + self.ms_events[i] = {} + for j in range(self.ms_config.num_micro_batches): + self.ms_events[i][j] = {} + for key in event_keys: + self.ms_events[i][j][key] = torch.npu.Event() + + def _build_ms_split_config(self): + if self.ms_config is not None: + self.ms_split_config = MSAttentionMetadataSplitConfig( + num_micro_batches=self.ms_config.num_micro_batches, + min_total_tokens_to_split=self.ms_config. + min_total_tokens_to_split, + min_prefill_tokens_to_split=self.ms_config. + min_prefill_tokens_to_split, + ) + + def try_wait_event(self, layer_index: int, micro_batch_index: int, + event_key: MSEventKey): + self.ms_events[layer_index][micro_batch_index][event_key].wait() + + def try_record_event(self, layer_index: int, micro_batch_index: int, + event_key: MSEventKey): + self.ms_events[layer_index][micro_batch_index][event_key].record() + + def split_micro_batch( + self, + attn_metadata: "AscendMLAMetadata", + intput_tensors: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors_keys: Optional[List[str]] = None, + ) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[ + List[torch.Tensor], List[List[torch.Tensor]]], Union[ + IntermediateTensors, List[IntermediateTensors]]]: + attn_metadata_list = attn_metadata.split_metadata_for_multistream( + self.ms_split_config) + if len(attn_metadata_list) == 1: + return False, attn_metadata_list[ + 0], intput_tensors, intermediate_tensors + split_index = attn_metadata_list[0].slot_mapping.shape[0] + input_tensors = split_micro_batches_tensors(intput_tensors, + split_index) + if intermediate_tensors is not None: + inter_tensors_list = split_micro_batches_tensors( + intermediate_tensors.tensors, split_index, + intermediate_tensors_keys) + intermediate_tensors = [ + IntermediateTensors(inter_tensors) + for inter_tensors in inter_tensors_list + ] + return True, attn_metadata_list, input_tensors, intermediate_tensors + + def merge_micro_batches( + self, input_tensors: Union[List[torch.Tensor], + List[List[torch.Tensor]]] + ) -> List[torch.Tensor]: + if input_tensors is None or isinstance(input_tensors[0], torch.Tensor): + return input_tensors + batch: List[Optional[torch.Tensor]] = [] + for tensors in input_tensors: + if tensors is None or tensors[0] is None: + batch.append(None) + else: + batch.append(torch.cat(tensors, dim=0)) + return batch + + +def make_multistream_metadata_ds( + start_layer: int, + end_layer: int, + causal_lm: bool = True, + multistream_config: Optional[MultiStreamConfig] = None, +): + if multistream_config is None: + return None + event_keylist = [ + MSEventKey.ATTN_COM_FINISH, + 
MSEventKey.ATTN_AR_FINISH, + MSEventKey.FFN_COM_FINISH, + MSEventKey.FFN_AR_FINISH, + MSEventKey.MOE_BEFORE_COMM, + MSEventKey.MOE_AFTER_COMM, + MSEventKey.MOE_SE_COMM_FINISH, + MSEventKey.MOE_SE_COMP_FINISH, + MSEventKey.MOE_GATE_FINISH, + ] + return MultiStreamMetadata( + calculate_stream=torch.npu.current_stream(), + communicate_stream=torch.npu.Stream(), + start_layer=start_layer, + end_layer=end_layer, + multistream_config=multistream_config, + event_keys=event_keylist, + causal_lm=causal_lm, + ) diff --git a/vllm_npu/multistream/ms_split.py b/vllm_npu/multistream/ms_split.py new file mode 100644 index 0000000..e1eb8b3 --- /dev/null +++ b/vllm_npu/multistream/ms_split.py @@ -0,0 +1,247 @@ +from copy import deepcopy +from typing import Any, List, Optional + +import numpy as np +import torch + +from vllm_npu.attention.attention_v1 import AscendAttentionState + +from .base import MSAttentionMetadataSplitConfig + + +def compute_split_seq_index( + query_lens: Optional[list[int]], + attn_state: AscendAttentionState, + num_tokens: int, + imbalance_ratio: float = 0.1, +) -> list[int]: + if attn_state != AscendAttentionState.DecodeOnly: + assert query_lens is not None + total_tokens = sum(query_lens) + # the first index in last split + tokens, split_index = 0, 0 + for value in query_lens: + tokens += value + split_index += 1 + if tokens >= total_tokens // 2: + # check the current split index + if abs(tokens - + total_tokens // 2) < total_tokens * imbalance_ratio: + return [tokens, split_index] + # check the previous split index + elif abs(tokens - total_tokens // 2 - + value) < total_tokens * imbalance_ratio: + return [tokens - value, split_index - 1] + # fail to split if it is imbalanced + # TODO: split tokens in seq + else: + return [0, 0] + else: + tokens = num_tokens // 2 + return [tokens, tokens] + return [0, 0] + + +def split_attn_tensor_type( + input_tensor: torch.Tensor, + index: int, +) -> List[torch.Tensor]: + return [input_tensor[:index], input_tensor[index:]] + + +def split_attn_int_type( + var: int, + index: int, +) -> List[torch.Tensor]: + return [min(var, index), max(var - index, 0)] + + +def model_input_split_v1_mla_attn( + attn_metadata, + _metadata_cls, + ms_split_config: MSAttentionMetadataSplitConfig, +) -> List[Any]: + assert 0 < ms_split_config.num_micro_batches < 3 + if attn_metadata is None: + return [attn_metadata] + [token_index, + seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return [attn_metadata] + + query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ), + dtype=int) + np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:]) + if attn_metadata.num_prefills > 0: + prefill_query_start_loc = np.zeros( + shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int) + np.cumsum(attn_metadata.prefill.query_lens, + out=prefill_query_start_loc[1:]) + + # split attn metadata + [slot_mapping_pre, + slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, + token_index) + [num_decodes_pre, + num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes, + seq_index) + [num_decode_tokens_pre, num_decode_tokens_post + ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index) + [num_prefills_pre, num_prefills_post + ] = split_attn_int_type(attn_metadata.num_prefills, + max(0, seq_index - attn_metadata.num_decodes)) + seq_lens = 
attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens + [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) + + query_start_loc_pre = query_start_loc_post = None + if attn_metadata.query_start_loc is not None: + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] + [block_table_pre, + block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, + seq_index) + assert attn_metadata.attn_mask is not None + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + # the attn_mla kernel in torch npu only accept 128*128 attn mask + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = attn_metadata.attn_state + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + # should be none in decode only state + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly + else: + # chunked prefill + if num_prefills_pre > 0: + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( + seq_lens_pre)].contiguous() + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + else: + attn_state_pre = AscendAttentionState.DecodeOnly + attn_mask_pre = None + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + from vllm_npu.attention.mla_v1 import (AscendMLADecodeMetadata, + AscendMLAPrefillMetadata) + if num_prefills_pre > 0: + # split metadata.prefill + [input_positions_pre, input_positions_post] = split_attn_tensor_type( + attn_metadata.prefill.input_positions, + token_index - attn_metadata.num_decode_tokens) + [block_tables_pre, block_tables_post + ] = split_attn_tensor_type(attn_metadata.prefill.block_table, + seq_index - attn_metadata.num_decodes) + [prefill_query_lens_pre, prefill_query_lens_post + ] = split_attn_tensor_type(attn_metadata.prefill.query_lens, + seq_index - attn_metadata.num_decodes) + prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[: + seq_index + + + 1 - + attn_metadata + . 
+ num_decodes] + prefill_query_start_loc_post = deepcopy( + attn_metadata.prefill.query_start_loc[seq_index - + attn_metadata.num_decodes:] + ) - attn_metadata.prefill.query_start_loc[seq_index - + attn_metadata.num_decodes] + context_len_pre = seq_lens_pre[attn_metadata.num_decodes:] + context_len_post = seq_lens_post + prefill_max_query_len_pre = max(prefill_query_lens_pre) + prefill_max_query_len_post = max(prefill_query_lens_post) + prefill_pre = AscendMLAPrefillMetadata( + attn_mask=attn_mask_pre, + query_lens=prefill_query_lens_pre, + seq_lens=seq_lens_pre, + query_start_loc=prefill_query_start_loc_pre, + input_positions=input_positions_pre, + context_lens=context_len_pre, + block_table=block_tables_pre, + max_query_len=prefill_max_query_len_pre, + max_seq_lens=context_len_pre.max().item(), + ) + prefill_post = AscendMLAPrefillMetadata( + attn_mask=attn_mask_post, + query_lens=prefill_query_lens_post, + seq_lens=seq_lens_post, + query_start_loc=prefill_query_start_loc_post, + input_positions=input_positions_post, + context_lens=context_len_post, + block_table=block_tables_post, + max_query_len=prefill_max_query_len_post, + max_seq_lens=context_len_post.max().item(), + ) + decode_pre = attn_metadata.decode + decode_post = None + else: + # prefill is None, split metadata.decode + [input_positions_pre, input_positions_post + ] = split_attn_tensor_type(attn_metadata.decode.input_positions, + token_index) + [block_tables_pre, block_tables_post + ] = split_attn_tensor_type(attn_metadata.decode.block_table, + seq_index) + [decode_seq_lens_pre, + decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) + decode_pre = AscendMLADecodeMetadata( + input_positions=input_positions_pre, + block_table=block_tables_pre, + seq_lens=decode_seq_lens_pre, + max_seq_lens=max(decode_seq_lens_pre), + seq_lens_list=decode_seq_lens_pre.tolist(), + ) + decode_post = AscendMLADecodeMetadata( + input_positions=input_positions_post, + block_table=block_tables_post, + seq_lens=decode_seq_lens_post, + max_seq_lens=max(decode_seq_lens_post), + seq_lens_list=decode_seq_lens_post.tolist(), + ) + prefill_pre = None + prefill_post = attn_metadata.prefill + # construct metadata + from vllm_npu.attention.mla_v1 import AscendMLAPrefillMetadata + attention_metadata_pre = _metadata_cls( + num_actual_tokens=token_index, + num_input_tokens=token_index, + head_dim=attn_metadata.head_dim, + slot_mapping=slot_mapping_pre, + seq_lens=seq_lens_pre, + query_start_loc=query_start_loc_pre, + block_tables=block_table_pre, + num_decodes=num_decodes_pre, + num_prefills=num_prefills_pre, + num_decode_tokens=num_decode_tokens_pre, + attn_state=attn_state_pre, + attn_mask=attn_mask_pre, + prefill=prefill_pre, + decode=decode_pre, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) + attention_metadata_post = _metadata_cls( + num_actual_tokens=attn_metadata.num_actual_tokens - token_index, + num_input_tokens=attn_metadata.num_input_tokens - token_index, + head_dim=attn_metadata.head_dim, + slot_mapping=slot_mapping_post, + seq_lens=seq_lens_post, + query_start_loc=query_start_loc_post, + block_tables=block_table_post, + num_decodes=num_decodes_post, + num_prefills=num_prefills_post, + num_decode_tokens=num_decode_tokens_post, + attn_mask=attn_mask_post, + attn_state=attn_state_post, + prefill=prefill_post, + decode=decode_post, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) + return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_npu/ops/__init__.py b/vllm_npu/ops/__init__.py index 
72bb3b3..1304b99 100644 --- a/vllm_npu/ops/__init__.py +++ b/vllm_npu/ops/__init__.py @@ -1 +1,57 @@ -"""Ascend NPU custom op registrations.""" +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch + +import vllm_npu.ops.common_fused_moe # noqa +import vllm_npu.ops.layernorm # noqa +import vllm_npu.ops.register_custom_ops # noqa +import vllm_npu.ops.vocab_parallel_embedding # noqa +from vllm_npu.ops.activation import AscendQuickGELU, AscendSiluAndMul +from vllm_npu.ops.rotary_embedding import ( + AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) + + +class dummyFusionOp: + default = None + + def __init__(self, name=""): + self.name = name + + +def register_dummy_fusion_op() -> None: + torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm") + torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp( + name="fused_add_rms_norm") + torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp( + name="static_scaled_fp8_quant") + torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp( + name="dynamic_scaled_fp8_quant") + torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp( + name="dynamic_per_token_scaled_fp8_quant") + torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp( + name="rms_norm_static_fp8_quant") + torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( + name="fused_add_rms_norm_static_fp8_quant") + torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp( + name="rms_norm_dynamic_per_token_quant") + + +__all__ = [ + "AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding", + "AscendDeepseekScalingRotaryEmbedding" +] diff --git a/vllm_npu/ops/activation.py b/vllm_npu/ops/activation.py index d8e8ecb..bd935f8 100644 --- a/vllm_npu/ops/activation.py +++ b/vllm_npu/ops/activation.py @@ -1,17 +1,44 @@ -""" -NPU-optimized activation functions for Ascend. - -Provides ``AscendSiluAndMul`` that uses ``torch_npu.npu_swiglu`` for -fused SiLU+Mul on NPU devices. -""" +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# import torch -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul + + +class AscendQuickGELU(QuickGELU): + + def forward_oot(self, x: torch.tensor) -> torch.Tensor: + import torch_npu + + out = torch_npu.npu_fast_gelu(x) + return out class AscendSiluAndMul(SiluAndMul): - """SiluAndMul using torch_npu.npu_swiglu on Ascend NPU.""" def forward_oot(self, x: torch.Tensor) -> torch.Tensor: - import torch_npu # noqa: F401 - return torch_npu.npu_swiglu(x) + import torch_npu + + from vllm_npu.utils import is_310p + + torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) + if is_310p(): + out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) + else: + out = torch_npu.npu_swiglu(x) + torch.ops.vllm.maybe_wait_prefetch_done(out) + return out diff --git a/vllm_npu/ops/attention.py b/vllm_npu/ops/attention.py new file mode 100644 index 0000000..05600ae --- /dev/null +++ b/vllm_npu/ops/attention.py @@ -0,0 +1,309 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/kernels/test_moe.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import torch +from vllm.model_executor.layers.linear import ColumnParallelLinear + + +# Implementation of vanilla chunked prefill, should be removed after the kernel is ready for +# all the corner case +def vanilla_chunked_prefill( + output: torch.Tensor, + query: torch.Tensor, # (num_tokens, heads, head_size) + key_cache: torch.Tensor, # (num_blocks, block_size, kv_heads, head_size) + value_cache: torch. 
+ Tensor, # (num_blocks, block_size, kv_heads, head_size,) + block_tables: torch.Tensor, # (num_seqs, max_num_blocks_per_seq) + cu_seqlen_q: torch.Tensor, # (num_seqs + 1,) + cu_seqlen_k: torch.Tensor, # (num_seqs + 1,) + max_seqlen_q: int, + max_seqlen_k: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool = True, +) -> torch.Tensor: + num_query_heads = query.shape[1] + head_dim = value_cache.shape[3] + num_kv_heads = value_cache.shape[2] + block_size = value_cache.shape[1] + num_batch = cu_seqlen_q.shape[0] - 1 + max_num_blocks_per_seq = block_tables.shape[1] + + key = key_cache[block_tables].view(num_batch, + max_num_blocks_per_seq * block_size, + num_kv_heads, head_dim) + + value = value_cache[block_tables].view(num_batch, + max_num_blocks_per_seq * block_size, + num_kv_heads, head_dim) + key = key[:, :max_seqlen_k, :, :] + value = value[:, :max_seqlen_k, :, :] + + seqlen_k = cu_seqlen_k[1:] - cu_seqlen_k[:-1] + seqlen_q = cu_seqlen_q[1:] - cu_seqlen_q[:-1] + seqlen_q = seqlen_q.view(-1, 1) + seqlen_k = seqlen_k.view(-1, 1) + seqlen_diff = seqlen_k - seqlen_q + q_idx_mask = (torch.arange(0, max_seqlen_q, + device="npu").view(1, -1).repeat(num_batch, 1)) + k_idx_mask = (torch.arange(0, max_seqlen_k, + device="npu").view(1, -1).repeat(num_batch, 1)) + q_mask = q_idx_mask < seqlen_q + k_mask = k_idx_mask < seqlen_k + + # calculate idx for causal mask of query [batch, max_seqlen_q] + causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask] + + # generate causal mask [batch, max_seqlen_q, max_seqlen_k] + tril_mask = torch.tril(torch.ones(max_seqlen_k, max_seqlen_k, + device="npu")) + tril_mask[tril_mask == 0] = float("-inf") + tril_mask[tril_mask == 1] = 0 + causal_mask = tril_mask[causal_mask_idx] + causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k], + device="npu").fill_(float("-inf")) + causal_mask_padding[q_mask] = causal_mask + # to [batch, num_heads, max_seqlen_q, max_seqlen_k] + causal_mask_padding = causal_mask_padding.unsqueeze(1) + + pad_q = torch.zeros( + [num_batch, max_seqlen_q, num_query_heads, head_dim], + device="npu", + dtype=query.dtype, + ) + pad_k = torch.zeros( + [num_batch, max_seqlen_k, num_kv_heads, head_dim], + device="npu", + dtype=key.dtype, + ) + pad_v = torch.zeros( + [num_batch, max_seqlen_k, num_kv_heads, head_dim], + device="npu", + dtype=value.dtype, + ) + pad_q[q_mask] = query + pad_k[k_mask] = key[k_mask] + pad_v[k_mask] = value[k_mask] + + if num_query_heads > num_kv_heads: + pad_k = pad_k.view( + [num_batch, max_seqlen_k, num_kv_heads, 1, head_dim]) + pad_k = pad_k.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view( + [num_batch, max_seqlen_k, num_query_heads, head_dim]) + pad_v = pad_v.view( + [num_batch, max_seqlen_k, num_kv_heads, 1, head_dim]) + pad_v = pad_v.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view( + [num_batch, max_seqlen_k, num_query_heads, head_dim]) + # permute to [b, h, n, k] + pad_q = pad_q.permute(0, 2, 1, 3) + pad_k = pad_k.permute(0, 2, 1, 3) + pad_v = pad_v.permute(0, 2, 1, 3) + attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k], + device="npu").fill_(float("-inf")) + attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0) + # [b, h, f, t] + attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k) + attn_weights *= scale + attn_mask = attn_mask.float() + attn_weights = attn_weights + attn_mask + if causal: + attn_weights = attn_weights + causal_mask_padding + + attn_weights = torch.softmax(attn_weights, dim=-1) + attn_output = 
torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float()) + attn_output = attn_output.permute(0, 2, 1, 3) + + attn_output = (attn_output[q_mask].view([-1, num_query_heads, + head_dim]).to(output.dtype)) + output.copy_(attn_output) + return attn_output + + +def vanilla_chunked_prefill_mla( + output: torch.Tensor, # (num_tokens, num_heads, v_head_dim) + query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim) + kv_cache: Tuple[ + torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv) + block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq) + query_lens: torch.Tensor, # (batch_size) + context_lens: torch.Tensor, # (batch_size) + kv_b_proj: ColumnParallelLinear, # () + max_query_len: int, + max_context_len: int, + nope_dim: int, + rope_dim: int, + v_head_dim: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool = True) -> None: + batch_size = block_tables.size(0) + assert len(kv_cache) > 1 + assert query_lens.size(0) == batch_size + num_heads = query.size(1) + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + block_size = nope_cache.size(1) + latent_kv_dim = nope_cache.size(-1) + max_num_blocks_per_seq = block_tables.size(1) + batch_size = query_lens.size(0) + nope_cache = nope_cache.squeeze() + # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe + # cached_kv_c: [batch_size, max_context_len, latent_kv] + # cached_k_pe: [batch_size, max_context_len, rope_dim] + cache_kv_c = nope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + latent_kv_dim)[:, :max_context_len, :] + cache_k_pe = rope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + rope_dim)[:, :max_context_len, :] + # get k_rope and v + # k_nope: [batch_size, max_context_len, num_heads, nope_dim] + # value: [batch_size, max_context_len, num_heads, v_head_dim] + k_nope, value = kv_b_proj(cache_kv_c)[0].view( + batch_size, max_context_len, num_heads, + nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1) + # key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim] + key = torch.cat( + [k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)], + dim=-1) + + context_lens = context_lens.view(-1, 1).to("npu") + query_lens = query_lens.view(-1, 1).to("npu") + seq_diff = context_lens - query_lens + + q_idx_mask = (torch.arange(0, max_query_len, + device="npu").view(1, -1).repeat(batch_size, 1)) + kv_c_idx_mask = (torch.arange(0, max_context_len, + device="npu").view(1, + -1).repeat(batch_size, 1)) + kv_c_mask = kv_c_idx_mask < context_lens + q_mask = q_idx_mask < query_lens + + # calculate idx for causal mask of query [batch, max_seqlen_q] + causal_mask_idx = (q_idx_mask + seq_diff)[q_mask] + + # generate causal mask [batch, max_seqlen_q, max_seqlen_k] + tril_mask = torch.tril( + torch.ones(max_context_len, max_context_len, device="npu")) + tril_mask[tril_mask == 0] = float("-inf") + tril_mask[tril_mask == 1] = 0 + causal_mask = tril_mask[causal_mask_idx] + causal_mask_padding = torch.empty( + [batch_size, max_query_len, max_context_len], + device="npu").fill_(float("-inf")) + causal_mask_padding[q_mask] = causal_mask + # to [batch, num_heads, max_seqlen_q, max_seqlen_k] + causal_mask_padding = causal_mask_padding.unsqueeze(1) + + pad_q = torch.zeros( + [batch_size, max_query_len, num_heads, rope_dim + nope_dim], + device="npu", + dtype=query.dtype, + ) + pad_k = torch.zeros( + [batch_size, max_context_len, num_heads, rope_dim + nope_dim], + device="npu", + 
dtype=key.dtype, + ) + pad_v = torch.zeros( + [batch_size, max_context_len, num_heads, v_head_dim], + device="npu", + dtype=value.dtype, + ) + num_query = torch.sum(q_mask).item() + num_add_query = num_query - query.size(0) + # mtp will come in + if num_add_query > 0: + add_query_size = query.size() + add_query_size = list(add_query_size) + add_query_size[0] = num_add_query + pad_tensor = torch.zeros(add_query_size, + dtype=query.dtype, + device=query.device) + query = torch.cat([query, pad_tensor], dim=0) + pad_q[q_mask] = query + pad_k[kv_c_mask] = key[kv_c_mask] + pad_v[kv_c_mask] = value[kv_c_mask] + + pad_q = pad_q.permute(0, 2, 1, 3) + pad_k = pad_k.permute(0, 2, 1, 3) + pad_v = pad_v.permute(0, 2, 1, 3) + attn_mask = torch.empty([batch_size, 1, 1, max_context_len], + device="npu").fill_(float("-inf")) + attn_mask[:, :, :, :max_context_len].masked_fill_( + kv_c_mask[:, None, None, :], 0) + # [b, h, f, t] + attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k) + attn_weights *= scale + attn_mask = attn_mask.float() + attn_weights = attn_weights + attn_mask + if causal: + attn_weights = attn_weights + causal_mask_padding + + attn_weights = torch.softmax(attn_weights, dim=-1) + attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float()) + attn_output = attn_output.permute(0, 2, 1, 3) + + attn_output = (attn_output[q_mask].view([-1, num_heads, + v_head_dim]).to(output.dtype)) + attn_output = attn_output.view_as(output) + output.copy_(attn_output) + return attn_output + + +def vanilla_decode_mla( + query: torch.Tensor, # [num_tokens, num_heads, latent_dim + rope_dim] + key_cache: torch. + Tensor, # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim] + num_kv_heads: int, + num_heads: int, + scale: float, + block_table: torch.Tensor, # [batch_size, max_block_size] + context_lens: List[int], + mla_vhead_size: int, + rope_dim: int, + output: torch.Tensor): + batch_size = block_table.size()[0] + max_block_size = block_table.size()[1] + reduce_dim = key_cache.size()[-1] + block_size = key_cache.size()[1] + latent_dim = reduce_dim - rope_dim + kv_c_and_pe = key_cache[block_table].view( + [batch_size, max_block_size * block_size, num_kv_heads, reduce_dim]) + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1) + # [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim] + # since the kv head is 1 in deepseek, we use expand here for perf + kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand( + -1, -1, num_heads, 1) + kv_c = kv_c_and_pe[..., :latent_dim] + kv_idx_mask = (torch.arange(0, max_context_len, + device="npu").view(1, + -1).repeat(batch_size, 1)) + # [batch_size, max_context_len] + kv_idx_mask = kv_idx_mask < context_lens + query = query.unsqueeze(1) + attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe) + attn_weights *= scale + attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float() + attn_weights = torch.softmax(attn_weights, dim=-1) + attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights, + kv_c.float()).view(-1, num_heads, latent_dim) + output.copy_(attn_output) + return output diff --git a/vllm_npu/ops/casual_conv1d.py b/vllm_npu/ops/casual_conv1d.py new file mode 100644 index 0000000..2d00889 --- /dev/null +++ b/vllm_npu/ops/casual_conv1d.py @@ -0,0 +1,539 @@ +# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# 
SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# mypy: ignore-errors + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. 
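+ so sequence i occupies x[..., query_start_loc[i]:query_start_loc[i+1]].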
+ for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + out_ref = [] + out_ref_b = [] + seqlens = query_start_loc[1:] - query_start_loc[:-1] + seqlens = seqlens.tolist() + splits = torch.split(x, seqlens, dim=-1) + + for i in range(len(seqlens)): + x_s = splits[i] + if cache_indices[i] == PAD_SLOT_ID: + continue + out_ref_b.append( + causal_conv1d_ref( + x_s, + weight, + bias, + activation=activation, + return_final_states=True, + final_states_out=conv_states[cache_indices[i]].unsqueeze(0), + initial_states=conv_states[cache_indices[i]] + if has_initial_state[i] else None)) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) + out_ref_tensor = torch.cat(out_ref, dim=0) + return out_ref_tensor + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + intermediate_conv_window_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_inter_seq: tl.constexpr, + stride_inter_step: tl.constexpr, + stride_inter_dim: tl.constexpr, + stride_inter_win: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, + SAVE_INTERMEDIATE: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing 
as this is not the actual sequence + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + + idx_seq) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + 1) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim + ) # [BLOCK_N] + + x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_state_ptrs_target = (conv_state_base + + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + 
(0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + # mask_1d = (idx_token < seqlen) & ( + # idx_feats < dim + # ) # token-index # feature-index + maskL = idx_feats < dim + maskR = tl.full(maskL.shape, False, tl.int1) + mask_1d = tl.where(idx_token < seqlen, maskL, maskR) + + o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + (idx_feats * stride_o_dim)) + + tl.store(o_ptrs, acc, mask=mask_1d) + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = (intermediate_conv_window_ptr + + conv_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim) + if KERNEL_WIDTH >= 2: + tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + + +def causal_conv1d_update_npu( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_window: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. 
+ conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch, ) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride( + ) # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + ) + stride_state_indices = (conv_state_indices.stride(0) + if conv_state_indices is not None else 0) + state_len = width - 1 + (seqlen - 1) # effective state_len needed + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + # prepare intermediate buffer strides if provided + if intermediate_conv_window is not None: + stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( + intermediate_conv_window.stride(0), + intermediate_conv_window.stride(1), + intermediate_conv_window.stride(2), + intermediate_conv_window.stride(3), + ) + else: + stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + intermediate_conv_window + if intermediate_conv_window is not None else x, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_inter_seq, + stride_inter_step, + stride_inter_dim, + stride_inter_win, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + 
HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=128, + SAVE_INTERMEDIATE=intermediate_conv_window is not None, + ) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/vllm_npu/ops/common_fused_moe.py b/vllm_npu/ops/common_fused_moe.py new file mode 100644 index 0000000..46039e8 --- /dev/null +++ b/vllm_npu/ops/common_fused_moe.py @@ -0,0 +1,451 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os.path +from typing import Any, Callable, Optional + +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, + tensor_model_parallel_all_reduce) +from vllm.forward_context import get_forward_context +from vllm.logger import logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map, + get_compressed_expert_map) +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.ascend_forward_context import MoECommType +from vllm_npu.distributed.parallel_state import get_mc2_group +from vllm_npu.eplb.core.eplb_utils import determine_default_log2phy_map +from vllm_npu.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_npu.ops.moe.experts_selector import select_experts +from vllm_npu.ops.moe.moe_comm_method import setup_moe_comm_method +from vllm_npu.quantization.w8a8_dynamic import \ + AscendW8A8DynamicFusedMoEMethod +from vllm_npu.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p, + is_enable_nz, npu_stream_switch, + shared_expert_dp_enabled, + shared_experts_compute_stream) + + +class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): + + def __init__(self, moe: FusedMoEConfig = None): + + super().__init__(moe=moe) + self.dynamic_eplb = get_ascend_config().dynamic_eplb + self.transpose = True + + def process_weights_after_loading(self, layer): + super(UnquantizedFusedMoEMethod, + self).process_weights_after_loading(layer) + if self.transpose: + w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( + 1, 2).contiguous() + layer.w13_weight = torch.nn.Parameter(w13_data, + requires_grad=False) + + w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( + 1, 2).contiguous() + layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + + self.transpose = False + else: + w13_data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w13_weight = torch.nn.Parameter(w13_data, + requires_grad=False) + + w2_data = self._maybe_pad_weight(layer.w2_weight.data) + 
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + + if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype): + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + enable_force_load_balance: bool = False, + shared_experts: Optional[Any] = None, + **kwargs) -> torch.Tensor: + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + topk_weights = topk_weights.to(x.dtype) + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + shared_experts=shared_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + dynamic_eplb=self.dynamic_eplb) + + +class AscendFusedMoE(FusedMoE): + moe_counter = -1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + num_experts = kwargs["num_experts"] + intermediate_size = kwargs["intermediate_size"] + + AscendFusedMoE.moe_counter += 1 + self.moe_instance_id = AscendFusedMoE.moe_counter + + self.expert_map = None + self.log2phy = None + + if self.quant_config is None: + self.quant_method = AscendUnquantizedFusedMoEMethod( + self.moe_config) + else: + self.quant_method = self.quant_config.get_quant_method( + self, self.layer_name) + + assert self.quant_method is not None + + self.moe_config.tp_group = get_tp_group() + self.moe_config.dp_group = get_dp_group() + self.moe_config.ep_group = get_ep_group() + self.moe_config.mc2_group = get_mc2_group() + ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + self.expert_map_path = ascend_config.expert_map_path + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.global_num_experts = num_experts + self.global_redundant_expert_num + # TODO: Flag for static expert placement. This is a temporary workaround + # to allow dynamic EPLB with float weights by skipping quantization checks. 
+ self.static_eplb_enabled = False + if self.custom_routing_function is None and self.e_score_correction_bias is not None: + vllm_config = get_current_vllm_config() + self.e_score_correction_bias.data = self.e_score_correction_bias.data.to( + dtype=vllm_config.model_config.dtype) + # static eplb initializing with expert_map_path + init_eplb_enable = False + if self.expert_map_path and os.path.exists( + self.expert_map_path) and os.access(self.expert_map_path, + os.R_OK): + self.expert_load_balancer = ExpertLoadBalancer( + self.expert_map_path, num_experts) + self.expert_load_balancer.check_expert_map_tensor() + self.global_redundant_expert_num = ( + self.expert_load_balancer.get_global_redundant_expert_num()) + self.global_num_experts = num_experts + self.global_redundant_expert_num + try: + self.local_num_experts, self.expert_map = ( + self.expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, self.ep_rank)) + self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( + self.moe_instance_id, self.ep_rank).npu() + init_eplb_enable = True + except Exception as e: + logger.warning( + f"Init expert map of mtp/eagle when using sample.{e}") + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, self.ep_rank, self.global_num_experts) + self.log2phy = determine_default_log2phy_map( + self.global_num_experts, self.ep_size, self.ep_rank).npu() + else: + # init moe. + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, self.ep_rank, self.global_num_experts) + # dynamic eplb initializing with not expert_map_path + if self.dynamic_eplb: + self.log2phy = determine_default_log2phy_map( + self.global_num_experts, self.ep_size, self.ep_rank).npu() + if self.expert_map is not None and isinstance(self.expert_map, + torch.Tensor): + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. 
Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) + local_num_experts = (torch.sum( + self.expert_map != -1) if self.expert_map is not None else + self.global_num_experts) + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, + dtype=torch.int64).npu() + + if init_eplb_enable and ( + not hasattr(self.quant_method, "quant_method") + or not isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod)): + raise ValueError("Eplb supports only w8a8_dynamic quantization.") + + self.moe_config.num_experts = self.global_num_experts + self.moe_config.num_local_experts = self.local_num_experts + self.moe_config.original_num_experts = num_experts + + moe_quant_params = { + "num_experts": local_num_experts, + "hidden_size": self.hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": self.params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + self.quant_method.create_weights(layer=self, **moe_quant_params) + + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + setup_moe_comm_method(self.moe_config) + + def update_expert_map(self, new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.log2phy + + def clear_moe_load(self): + if self.moe_load is not None: + self.moe_load.zero_() + + def maybe_all_reduce_tensor_model_parallel( + self, final_hidden_states: torch.Tensor): + """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`, + and `alltoallcommimpl`, we do not need to all-reduce the final outputs since + the outputs are already aggregated across tensor parallel ranks in the + `finalize` function. In `allgathercommimpl`, we still need to all-reduce the + outputs since each rank only has partial outputs. + """ + return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. + quantized_x_for_share, dynamic_scale_for_share = None, None + + forward_context = get_forward_context() + + # Load balancing for token distribution among experts in dummy_run + # TODO: The community only considers load balancing when DP > 1. + # This approach may overlook some extreme scenarios. + enable_force_load_balance = forward_context.in_profile_run + + hidden_states, router_logits = forward_context.moe_comm_method.prepare( + hidden_states=hidden_states, + router_logits=router_logits, + replace_allreduce=forward_context.sp_enabled, + enable_shared_expert_dp=self.enable_shared_expert_dp) + + # Matrix multiply. 
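+        # For the unquantized path, apply() first runs select_experts() and then
+        # dispatches to forward_context.moe_comm_method.fused_experts(), so the
+        # concrete compute/communication strategy (allgather, alltoall, mc2) is
+        # chosen per forward pass by the forward context.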
+ final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + shared_experts=None, + enable_force_load_balance=enable_force_load_balance, + log2phy=self.log2phy, + global_redundant_expert_num=self.global_redundant_expert_num) + + if isinstance(final_hidden_states, tuple): + final_hidden_states, group_list_type, expert_tokens = final_hidden_states + + if self.dynamic_eplb: + self.moe_load += expert_tokens if group_list_type == 1 else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + + final_hidden_states = forward_context.moe_comm_method.finalize( + hidden_states=final_hidden_states, + reduce_results=self.reduce_results) + + return final_hidden_states + + def transpose_weight(self, loaded_weight, expert_data, shard_dim): + # Ensure training and inference weight shapes match during RL weight updates + if ( + loaded_weight.shape[1] != expert_data.shape[1] and \ + loaded_weight.shape[0] != expert_data.shape[0] + ): + shard_dim = int(not shard_dim) + loaded_weight = loaded_weight.transpose(0, 1).contiguous() + return loaded_weight, shard_dim + + def _load_w13(self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + loaded_weight, shard_dim = self.transpose_weight( + loaded_weight, expert_data, shard_dim) + shard_size = expert_data.shape[shard_dim] // 2 + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + loaded_weight, shard_dim = self.transpose_weight( + loaded_weight, expert_data, shard_dim) + shard_size = expert_data.shape[shard_dim] + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. 
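+        # Unlike w13, w2 is a single logical shard, so after the optional tp
+        # narrowing above the loaded weight is copied in directly.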
+ expert_data.copy_(loaded_weight) + + +class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + AscendFusedMoE.__init__(self, **kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + self.shared_expert_stream = None + ascend_config = get_ascend_config() + self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert + if enable_sp(): + logger.info_once( + "Sequence parallelism is enabled, shared experts are replicated for best performance." + ) + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + shared_out, fused_out = AscendFusedMoE.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + # Make sure the shared experts stream begins after hidden_states are ready. + if self.multistream_overlap_shared_expert: + shared_experts_compute_stream().wait_stream( # type: ignore + torch.npu.current_stream()) + with npu_stream_switch(shared_experts_compute_stream(), + enabled=self.multistream_overlap_shared_expert): + # Use a separate stream to run shared experts. + # Note that currently we only support calculations in separate streams with aclgraph. + # Communication operations in another stream might cause unknown errors. + shared_out = self._shared_experts(hidden_states) + + fused_output = AscendFusedMoE.forward_impl( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + # Make sure the default stream waits for the shared experts stream to finish. 
+ if self.multistream_overlap_shared_expert: + torch.npu.current_stream().wait_stream( + shared_experts_compute_stream()) + # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` + forward_context = get_forward_context() + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \ + and not shared_expert_dp_enabled(): + shared_out = tensor_model_parallel_all_reduce(shared_out) + return shared_out, fused_output diff --git a/vllm_npu/ops/expert_load_balancer.py b/vllm_npu/ops/expert_load_balancer.py new file mode 100644 index 0000000..08c0106 --- /dev/null +++ b/vllm_npu/ops/expert_load_balancer.py @@ -0,0 +1,119 @@ +import json +import random +from typing import Dict, List + +import torch +import torch.distributed as dist + + +class ExpertLoadBalancer(object): + + def __init__(self, expert_map_path, num_experts): + self.expert_map_path = expert_map_path + self.num_experts = num_experts + self.tensor_data = [] + self.expert_map_tensor, self.layers_num, self.ranks_num = ( + self._expert_file_to_tensor()) + self.global_expert_num = num_experts + self.get_global_redundant_expert_num( + ) + self.expert_placement_map = self.generate_expert_placement_map() + + def _expert_file_to_tensor(self): + with open(self.expert_map_path, "r") as f: + data = json.load(f) + layers_num = data["moe_layer_count"] + gpus_num = data["layer_list"][0]["device_count"] + + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data.append(device["device_expert"]) + self.tensor_data.append(device_data) + expert_map_tensor = torch.tensor(self.tensor_data, dtype=torch.int32) + return expert_map_tensor, layers_num, gpus_num + + def generate_index_dicts(self, tensor_2d): + dict_list = [] + current_idx = 0 + + for row in tensor_2d: + value_to_index = {} + for i in range(row.size(0)): + value = row[i].item() + value_to_index[value] = current_idx + i + dict_list.append(value_to_index) + current_idx += row.size(0) + + return dict_list + + def generate_expert_placement_map(self): + expert_placement_map = torch.full( + (self.layers_num, self.ranks_num, self.num_experts), + -1, + dtype=torch.int32, + ) + for layer_id in range(self.layers_num): + for gpu_id in range(self.ranks_num): + e_ids = self.expert_map_tensor[layer_id, gpu_id] + expert_placement_map[layer_id, gpu_id, + e_ids] = torch.arange(len(e_ids), + dtype=torch.int32) + return expert_placement_map + + def generate_log2phy_expert_map(self, layer_id): + concatenated = torch.flatten(self.expert_map_tensor[layer_id]) + rank_expert_to_global = self.generate_index_dicts( + self.expert_map_tensor[layer_id]) + result_dict: Dict[int, List[int]] = {} + for idx, value in enumerate(concatenated): + key = value.item() + if key not in result_dict: + result_dict[key] = [] + result_dict[key].append(idx) + + log2phy_map = torch.full((self.ranks_num, self.num_experts), + -1, + dtype=torch.int32) + for rank in range(self.ranks_num): + for key in result_dict: + indices_in_concat = result_dict[key] + if key in rank_expert_to_global[rank]: + log2phy_map[rank][key] = rank_expert_to_global[rank][key] + else: + chosen_index = random.choice(indices_in_concat) + log2phy_map[rank][key] = chosen_index + return log2phy_map + + def get_rank_placement_map(self, layer_id, rank_id): + layer_expert_map = self.expert_placement_map[layer_id] + rank_expert_map = layer_expert_map[rank_id].to( + torch.npu.current_device()) + rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, 
-1)).item() + return rank_local_expert_num, rank_expert_map + + def get_rank_log2phy_map(self, layer_id, rank_id): + layer_log2phy_map = self.generate_log2phy_expert_map(layer_id) + return layer_log2phy_map[rank_id] + + def get_global_redundant_expert_num(self): + global_redundant_expert_num = ( + len(self.expert_map_tensor[0][0]) * self.ranks_num - + self.num_experts) + return global_redundant_expert_num + + def check_expert_map_tensor(self): + if dist.is_initialized(): + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + all_expert_maps = [None for _ in range(world_size)] + dist.all_gather_object(all_expert_maps, self.tensor_data) + for rank_id, expert_map_tensor in enumerate(all_expert_maps): + if self.tensor_data != expert_map_tensor: + raise ValueError( + f"The expert map of rank{rank} is not equal to rank{rank_id}" + ) + return True + except Exception as e: + raise ValueError( + f"The expert maps of all ranks are inconsistency: {e}") diff --git a/vllm_npu/ops/fla.py b/vllm_npu/ops/fla.py new file mode 100644 index 0000000..7903900 --- /dev/null +++ b/vllm_npu/ops/fla.py @@ -0,0 +1,299 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. +# mypy: ignore-errors + +import torch +import torch.nn.functional as F +from vllm.triton_utils import tl, triton + +MAX_CORES = 65535 + + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, +}) +@triton.jit +def layer_norm_fwd_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X_base + N, # number of columns in X_base + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + N_CORES: tl.constexpr, +): + # Map the program id to the row of X_base and Y_base it should compute. 
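+    # The launch grid is capped at N_CORES (65535) rows, so when M is larger
+    # each program strides over several rows across the n_iters iterations below.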
+ row = tl.program_id(0) + group = tl.program_id(1) + + BLOCK_ROWS = M if M < N_CORES else N_CORES + n_iters = M // BLOCK_ROWS + remain = M % BLOCK_ROWS + if row < remain: + n_iters = n_iters + 1 + + for i in tl.range(n_iters): + X_base = X + (i * BLOCK_ROWS * + stride_x_row) + row * stride_x_row + group * N + Y_base = Y + (i * BLOCK_ROWS * + stride_y_row) + row * stride_y_row + group * N + if HAS_Z: + Z_base = Z + (i * BLOCK_ROWS * + stride_z_row) + row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean_base = Mean + (i * BLOCK_ROWS) + group * M + Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M + W_base = W + group * N + if HAS_BIAS: + B_base = B + group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean_base + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd_base + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W_base + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B_base + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z_base + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y_base + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N, ) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N, ) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + if not is_rms_norm else None) + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M if M < MAX_CORES else MAX_CORES, ngroups) + with torch.npu.device(x.device.index): + layer_norm_fwd_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + N_CORES=MAX_CORES, + num_warps=num_warps, + ) + return out, mean, rstd + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + ): + """If z is not None, we 
do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = _layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + return y.reshape(x_shape_og) + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, sequence_length, num_heads, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - num_heads % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + tot_heads = num_heads + pad_size + scale = 1 / (query.shape[-1]**0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - + g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -( + (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = (torch.zeros(batch_size, sequence_length, + k_head_dim, v_head_dim).to(value) if + initial_state is None else initial_state.to(value)) + + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=1) + + # for each chunk + for i in range(0, tot_heads // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * + decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * + (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( + -1, -2) @ v_new) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = 
core_attn_out.reshape(core_attn_out.shape[0], + core_attn_out.shape[1], -1, + core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :num_heads] + core_attn_out = core_attn_out.transpose(1, + 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state diff --git a/vllm_npu/ops/layernorm.py b/vllm_npu/ops/layernorm.py index 96b0cd9..6e29fbc 100644 --- a/vllm_npu/ops/layernorm.py +++ b/vllm_npu/ops/layernorm.py @@ -1,36 +1,213 @@ -""" -NPU-optimized layer normalization for Ascend. +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# -Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with -``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route -to NPU kernels automatically. -""" - -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, cast import torch -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm + + +def _addrmsnorm_forward_oot( + self, + x: torch.Tensor, + residual: torch.Tensor, + layer: Optional[torch.nn.Module] = None, + bias: Optional[torch.nn.Parameter] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + import torch_npu + + from vllm_npu.utils import is_310p + + if layer is not None and not is_310p(): + layer_cls_name = layer.__class__.__name__ + try: + weight_prefetch_method = get_forward_context( + ).weight_prefetch_method + except AssertionError: + weight_prefetch_method = None + + # prefetch qkvo_proj.weight preprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( + layer_cls_name=layer_cls_name, + weight=layer.weight, + start_flag=x, + ) + # add_rms_norm_quant + x, _, residual = torch_npu.npu_add_rms_norm_quant( + x, + residual, + self.weight, + layer.aclnn_input_scale, + layer.aclnn_input_offset, + beta=bias, + epsilon=self.variance_epsilon) + + # prefetch qkvo_proj.weight postprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( + layer_cls_name=layer_cls_name, + stop_flag=x, + ) + + else: + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) + if bias is not None: + x.add_(bias) + torch.ops.vllm.maybe_wait_prefetch_done(x) + return x, residual class AscendRMSNorm(RMSNorm): - """RMSNorm using Ascend NPU fused kernels. - Uses ``torch_npu.npu_rms_norm`` for standalone normalization and - ``torch_npu.npu_add_rms_norm`` for fused residual-add + norm. 
- """ + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) + vllm_config = get_current_vllm_config() + self.bias = None + # quantization with anti_method m4 will generate none-zero norm bias + if vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) def forward_oot( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - import torch_npu # noqa: F401 + import torch_npu if residual is not None: - x, _, residual = torch_npu.npu_add_rms_norm( - x, residual, self.weight, self.variance_epsilon - ) + assert x.size(0) == residual.size(0) + x, residual = _addrmsnorm_forward_oot( + self, x, residual, self.next_need_quant_fusion_linear, + self.bias) + return x, residual + x, residual = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + if self.bias is not None: + x.add_(self.bias) + return x + + @property + def next_need_quant_fusion_linear(self): + try: + forward_context = get_forward_context() + if not forward_context.addrmsnorm_quant_fusion_enabled or \ + forward_context.layer_idx == forward_context.num_hidden_layers: + return None + except AssertionError: + return None + + next_linear = None + model_instance = forward_context.model_instance + layer_idx = forward_context.layer_idx + fusion_linear = forward_context.fusion_linear + next_linear = None + if fusion_linear == "qkv_dense": + next_linear = model_instance.model.layers[ + layer_idx].self_attn.qkv_proj + forward_context.fusion_linear = "gate_up_dense" + elif fusion_linear == "gate_up_dense": + next_linear = model_instance.model.layers[ + layer_idx].mlp.gate_up_proj + forward_context.fusion_linear = "qkv_dense" + # if prefetch_mlp_weight enabled, following accumulation operation + # does not need to be repeated + if not forward_context.prefetch_mlp_enabled: + forward_context.layer_idx += 1 + elif fusion_linear == "qkv_moe": + next_linear = model_instance.model.layers[ + layer_idx].self_attn.qkv_proj + forward_context.fusion_linear = "gate_moe" + elif fusion_linear == "gate_moe": + forward_context.fusion_linear = "qkv_moe" + forward_context.layer_idx += 1 + from vllm_npu.quantization.w8a8 import AscendW8A8LinearMethod + if next_linear is not None and \ + not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod): + next_linear = None + return next_linear + + +class AscendQuantRMSNorm(AscendRMSNorm): + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) + + def forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + x, residual = super().forward_oot(x, residual) + return x.add_(self.bias), residual + return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias) + + +class AscendGemmaRMSNorm(GemmaRMSNorm): + + def forward_oot( + self, + x: torch.Tensor, + residual: 
Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + import torch_npu + + from vllm_npu.utils import is_310p + if residual is not None: + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, 1.0 + self.weight, self.variance_epsilon) return x, residual - x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) + x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, + self.variance_epsilon) return x diff --git a/vllm_npu/ops/linear.py b/vllm_npu/ops/linear.py new file mode 100644 index 0000000..b015bd9 --- /dev/null +++ b/vllm_npu/ops/linear.py @@ -0,0 +1,466 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +To customize linear communication groups or forward of classes in this file, +extend new linear operations in linear_op.py. +The classes in this file should not be modified, including AscendQKVParallelLinear, +AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear, +AscendRowParallelLinear and AscendColumnParallelLinear. 
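+Each class resolves an optional custom op via get_parallel_op() /
+get_replicated_op(); when one is returned, forward() delegates to
+custom_op.apply(input_) instead of the stock vLLM linear path.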
+""" + +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch_npu +from torch.nn.parameter import Parameter +from vllm.config import get_current_vllm_config +from vllm.distributed import divide +from vllm.model_executor.layers.linear import ( # noqa + WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase, + MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase, + ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig +from vllm.model_executor.utils import set_weight_attrs + +from vllm_npu.ops.linear_op import get_parallel_op, get_replicated_op +from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz + + +class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): + """Linear method without quantization""" + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + if (is_enable_nz(layer.weight.data.dtype)): + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_NZ) + + +# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group +class AscendLinearBase(LinearBase): + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + nn.Module.__init__(self) + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.quant_config = quant_config + self.prefix = prefix + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = AscendUnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, + prefix=prefix) + self.return_bias = return_bias + self.disable_tp = disable_tp + + +class AscendQKVParallelLinear(QKVParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + self.custom_op, _, tp_size = get_parallel_op(disable_tp, prefix, self, + "column") + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
+ self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + AscendColumnParallelLinear.__init__(self, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + def forward( + self, + input_, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.custom_op is not None: + return self.custom_op.apply(input_) + + return super().forward(input_) + + +class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Use the MLP tensor parallelism group in the MLP module, + and the original TP group in other modules. + """ + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( + disable_tp, prefix, self, "column") + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group + self.output_sizes = output_sizes + assert all(output_size % self.tp_size == 0 + for output_size in output_sizes) + AscendColumnParallelLinear.__init__(self, + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + def forward( + self, + input_, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.custom_op is not None: + return self.custom_op.apply(input_) + + return super().forward(input_) + + +class AscendRowParallelLinear(RowParallelLinear): + """Linear layer with row parallelism. + Use the MLP tensor parallelism group in the MLP module, + and the original TP group in other modules. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + compilation_config = get_current_vllm_config().compilation_config + # TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete. 
+ # TODO(MengqingCao): Remove the empty string check, after specifying the prefix in linear layers of some models in the vLLM. + if prefix in compilation_config.static_forward_context and \ + prefix != "" and \ + "visual" not in prefix: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( + disable_tp, prefix, self, "row") + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group + # Divide the weight matrix along the first dimension. + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + AscendLinearBase.__init__(self, + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + if self.custom_op is not None: + self.custom_op.update_attrs() + + def forward( + self, + input_, + is_prefill: bool = True, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.custom_op is not None: + return self.custom_op.apply(input_) + + return super().forward(input_) + + +class AscendColumnParallelLinear(ColumnParallelLinear): + """Linear layer with column parallelism. + + Use the MLP tensor parallelism group in the MLP module, + and the original TP group in other modules. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[list[int]] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( + disable_tp, prefix, self, "column") + # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + AscendLinearBase.__init__(self, + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + if self.custom_op is not None: + self.custom_op.update_attrs() + + def forward( + self, + input_, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.custom_op is not None: + return self.custom_op.apply(input_) + + return super().forward(input_) + + +class AscendReplicatedLinear(ReplicatedLinear): + """Ascend Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + disable_tp: Take no effect for replicated linear layers. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + self.custom_op = get_replicated_op(disable_tp, prefix, self) + # If MergedReplicatedLinear, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = self.output_sizes + else: + self.output_partition_sizes = [output_size] + + AscendLinearBase.__init__(self, + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + # All the linear layer supports quant method. 
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         self.input_size, [self.output_size],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=self.params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+        if self.custom_op is not None:
+            self.custom_op.update_attrs()
+
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)
+
+        return super().forward(input_)
diff --git a/vllm_npu/ops/linear_op.py b/vllm_npu/ops/linear_op.py
new file mode 100644
index 0000000..21b6453
--- /dev/null
+++ b/vllm_npu/ops/linear_op.py
@@ -0,0 +1,531 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file extends the functionality of linear operations by encapsulating custom
+communication groups and forward functions into classes (linear ops).
+
+Current class inheritance structure:
+CustomLinearOp
+├── CustomColumnParallelOp
+│   ├── MLPColumnParallelOp
+│   └── SequenceColumnParallelOp
+├── CustomRowParallelOp
+│   ├── MLPRowParallelOp
+│   ├── OProjRowParallelOp
+│   ├── MatmulAllreduceRowParallelOp
+│   └── SequenceRowParallelOp
+└── CustomReplicatedOp
+
+How to extend a new linear op? Taking a column parallel op as an example:
+1. Inherit from CustomColumnParallelOp and create a new class, e.g. MyColumnParallelOp.
+2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group property.
+3. Override the apply_impl method (wrapped by the base class's apply) as needed; it replaces the original linear.forward computation.
+4. Add selection logic for MyColumnParallelOp in _get_column_parallel_op, typically based on the prefix and configuration checks.
+Row parallel ops follow a similar approach: inherit from CustomRowParallelOp and register the new class in _get_row_parallel_op.
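An illustrative sketch of the extension recipe described above (editor's example, not part of this patch). It assumes the CustomColumnParallelOp base class and the _get_column_parallel_op selector defined later in this file; the op name and the "my_proj" prefix check are hypothetical:

```python
from vllm.distributed.parallel_state import get_tp_group


class MyColumnParallelOp(CustomColumnParallelOp):
    # Step 2 (optional): keep the default TP group, or return a custom
    # communication group here instead.
    @property
    def comm_group(self):
        return get_tp_group()

    # Step 3: the computation that replaces the original linear.forward;
    # CustomLinearOp.apply() wraps this and handles return_bias.
    def apply_impl(self, input_):
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self.layer, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


# Step 4: select the new op inside _get_column_parallel_op, e.g.
#     if "my_proj" in prefix:
#         return MyColumnParallelOp(layer)
```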
+""" + +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch_npu +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter +from vllm.distributed import (split_tensor_along_last_dim, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.distributed.parallel_state import get_tp_group +from vllm.forward_context import get_forward_context + +from vllm_npu.distributed.parallel_state import (get_mlp_tp_group, + get_otp_group) +from vllm_npu.utils import (dense_optim_enable, enable_sp, + matmul_allreduce_enable, mlp_tp_enable, + oproj_tp_enable, shared_expert_dp_enabled) + + +class CustomLinearOp: + + def __init__(self, layer): + self.layer = layer + self.bias = None + self.skip_bias_add = None + self.return_bias = None + self.quant_method = None + + # Custom communication group, while determining weight sharding + @property + def comm_group(self): + return get_tp_group() + + @property + def tp_rank(self): + return self.comm_group.rank_in_group + + @property + def tp_size(self): + return self.comm_group.world_size + + # Update the attributes required by apply(), obtaining them from the layer. + # Call this after the layer completes its initialization, specifically at the end of layer.init(). + def update_attrs(self): + if hasattr(self.layer, "bias"): + self.bias = self.layer.bias + self.skip_bias_add = self.layer.skip_bias_add + self.return_bias = self.layer.return_bias + self.quant_method = self.layer.quant_method + self.prefix = self.layer.prefix + + def apply_impl(self, input_): + raise NotImplementedError + + # Replace layer.forward to customize the layer computation process. + def apply(self, input_): + output, output_bias = self.apply_impl(input_) + if not self.return_bias: + return output + return output, output_bias + + +class CustomColumnParallelOp(CustomLinearOp): + + def __init__(self, layer): + super().__init__(layer) + self.gather_output = None + + def update_attrs(self): + super().update_attrs() + self.gather_output = self.layer.gather_output + + +class CustomRowParallelOp(CustomLinearOp): + + def __init__(self, layer): + super().__init__(layer) + self.reduce_results = None + self.input_is_parallel = None + self.input_size_per_partition = None + + def update_attrs(self): + super().update_attrs() + self.input_is_parallel = self.layer.input_is_parallel + self.reduce_results = self.layer.reduce_results + self.input_size_per_partition = self.layer.input_size_per_partition + + def apply(self, input_): + output, output_bias = self.apply_impl(input_) + if dense_optim_enable(): + torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix) + if not self.return_bias: + return output + return output, output_bias + + +class CustomReplicatedOp(CustomLinearOp): + + def apply_impl(self, input_): + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + + output = self.quant_method.apply(self.layer, input_, bias) + output_bias = self.bias if self.skip_bias_add else None + + return output, output_bias + + +class MLPColumnParallelOp(CustomColumnParallelOp): + + def __init__(self, layer): + super().__init__(layer) + + @property + def comm_group(self): + return get_mlp_tp_group() + + def apply_impl( + self, + input_: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + # Matrix multiply. 
+ assert self.quant_method is not None + input_parallel = self.comm_group.all_gather(input_, 0) + output = self.quant_method.apply(self.layer, input_parallel, bias) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class MLPRowParallelOp(CustomRowParallelOp): + + def __init__(self, layer): + super().__init__(layer) + + @property + def comm_group(self): + return get_mlp_tp_group() + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + assert self.quant_method is not None + bias_ = None if (self.tp_rank > 0 + or self.skip_bias_add) else self.layer.bias + output_parallel = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + output = self.comm_group.reduce_scatter(output_parallel, 0) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class OProjRowParallelOp(CustomRowParallelOp): + + def __init__(self, layer): + super().__init__(layer) + + @property + def comm_group(self): + return get_otp_group() + + def apply_impl( + self, + input_: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Prepare tensors for all-to-all communication + local_batch_size = input_parallel.size(0) + chunk_size = self.input_size_per_partition + total_batch_size = local_batch_size * self.tp_size + + # Reshape tensor for efficient cross-device transfer: + # [batch, dim] -> [tp_size, batch, chunk] -> flattened + send_buf = (input_parallel.reshape(-1, + self.tp_size, chunk_size).transpose( + 0, 1).contiguous().view(-1)) + + # Create receive buffer + recv_buf = torch.empty(total_batch_size * chunk_size, + dtype=input_parallel.dtype, + device=input_parallel.device) + + # Perform all-to-all communication + dist.all_to_all_single(recv_buf, + send_buf, + group=self.comm_group.device_group) + input_parallel = recv_buf.view(total_batch_size, chunk_size) + + # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1 + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + + # otp-specific: Combine partial results across devices + output = self.comm_group.reduce_scatter(output_parallel, dim=0) + output = output.view(input_.shape[0], self.layer.output_size) + + # Handle bias return based on configuration + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def update_attrs(self): + super().update_attrs() + self.input_is_parallel = self.layer.input_is_parallel + self.input_size_per_partition = self.layer.input_size_per_partition + + +class MatmulAllreduceRowParallelOp(CustomRowParallelOp): + _HCOMM_INFO = None + + def __init__(self, layer): + super().__init__(layer) + self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group) + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + 
splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + """Calculate the output tensor of forward by considering + fusing communication and computation.""" + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + if self.reduce_results and self.tp_size > 1: + output = torch_npu.npu_mm_all_reduce_base(input_parallel, + self.weight_t, + self.hcomm_info, + bias=bias_) + else: + assert self.quant_method is not None + output = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + @classmethod + def get_hcomm_info(cls, group: ProcessGroup) -> str: + """Get the HCCL communication information for the given group.""" + if cls._HCOMM_INFO is not None: + return cls._HCOMM_INFO + + rank = torch.distributed.get_rank(group) + if torch.__version__ > "2.0": + global_rank = torch.distributed.get_global_rank(group, rank) + cls._HCOMM_INFO = group._get_backend( + torch.device("npu")).get_hccl_comm_name(global_rank) + else: + cls._HCOMM_INFO = group.get_hccl_comm_name(rank) + return cls._HCOMM_INFO + + def update_attrs(self): + super().update_attrs() + self.weight_t = self.layer.weight.t() + + +class SequenceColumnParallelOp(CustomColumnParallelOp): + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) + output_parallel = self.quant_method.apply(self.layer, input_, bias) + + if self.gather_output: + # All-gather across the partitions. + output = self.comm_group.all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class SequenceRowParallelOp(CustomRowParallelOp): + + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. 
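Looking back at OProjRowParallelOp.apply_impl above: the reshape/transpose that builds `send_buf` lays out each destination rank's chunks contiguously before `dist.all_to_all_single`. A tiny, self-contained illustration with made-up sizes (editor's example, not part of the patch):

```python
import torch

B, tp_size, chunk = 2, 2, 3                      # hypothetical batch, TP degree, chunk size
x = torch.arange(B * tp_size * chunk).reshape(B, tp_size * chunk)
# Rows: [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]
send_buf = x.reshape(-1, tp_size, chunk).transpose(0, 1).contiguous().view(-1)
# Chunks bound for rank 0 (each row's first `chunk` columns) are now contiguous,
# followed by the chunks bound for rank 1 -- the layout all_to_all_single expects.
assert send_buf.tolist() == [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11]
```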
+ """ + + if self.input_is_parallel: + input_parallel = input_ + else: + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + assert self.quant_method is not None + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + + if self.tp_size == 1 or not self.reduce_results: + output = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + else: + output = torch.ops.vllm.matmul_and_reduce(input_parallel, + self.prefix) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def matmul_and_reduce(self, input_parallel: torch.Tensor, + bias_: Optional[Parameter]) -> torch.Tensor: + assert self.quant_method is not None + try: + forward_context = get_forward_context() + sp_enabled = forward_context.sp_enabled + mmrs_fusion = forward_context.mmrs_fusion + except AssertionError: + sp_enabled = False + mmrs_fusion = False + + x = input_parallel + + if not sp_enabled: + output_parallel = self.layer.quant_method.apply(self.layer, + x, + bias=bias_) + return tensor_model_parallel_all_reduce(output_parallel) + + pad_size = forward_context.pad_size + if pad_size > 0: + x = F.pad(x, (0, 0, 0, pad_size)) + + world_size = self.layer.tp_size + comm_mode = "aiv" + hcom_name = get_tp_group().device_group._get_backend( + torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank) + + from vllm.model_executor.layers.linear import UnquantizedLinearMethod + + from vllm_npu.quantization.quant_config import AscendLinearMethod + from vllm_npu.quantization.w8a8 import (AscendW8A8LinearMethod, + quant_per_tensor) + + # For unquant + if mmrs_fusion and isinstance(self.layer.quant_method, + UnquantizedLinearMethod): + output = torch_npu.npu_mm_reduce_scatter_base( + x, + self.layer.weight.t(), + hcom_name, + world_size, + reduce_op="sum", + bias=None, + comm_turn=0, + comm_mode=comm_mode) + if bias_ is not None: + output.add_(bias_) + # For w8a8 quant + elif mmrs_fusion and ( + isinstance(self.layer.quant_method, AscendLinearMethod) + and isinstance(self.layer.quant_method.quant_method, + AscendW8A8LinearMethod)): + if x.dtype != torch.int8: + x_quant = quant_per_tensor( + x, self.layer.aclnn_input_scale_reciprocal, + self.layer.aclnn_input_offset) + else: + x_quant = x + quant_bias = self.layer.quant_bias + deq_scale = self.layer.deq_scale + output_dtype = torch.bfloat16 + output = torch_npu.npu_mm_reduce_scatter_base( + x_quant, + self.layer.weight, + hcom_name, + world_size, + reduce_op="sum", + bias=None, + comm_turn=0, + x2_scale=deq_scale, + output_dtype=output_dtype, + comm_mode=comm_mode) + output = torch.add( + output, + torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype)) + else: + output_parallel = self.layer.quant_method.apply(self.layer, + x, + bias=bias_) + output = tensor_model_parallel_reduce_scatter(output_parallel, 0) + + return output + + def update_attrs(self): + super().update_attrs() + self.input_is_parallel = self.layer.input_is_parallel + self.reduce_results = self.layer.reduce_results + + +def _get_column_parallel_op( + prefix, layer +) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]: + if mlp_tp_enable() and "gate_up_proj" in prefix: + return MLPColumnParallelOp(layer) + if enable_sp(): + if "shared_expert" in prefix: + return None + if "gate_up_proj" in prefix: + return SequenceColumnParallelOp(layer) + if "in_proj" in prefix: + return SequenceColumnParallelOp(layer) + if "qkv_proj" in prefix or "conv1d" 
in prefix: + return SequenceColumnParallelOp(layer) + + return None + + +def _get_row_parallel_op( + prefix, layer +) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp, + MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]: + if "down_proj" in prefix and mlp_tp_enable(): + return MLPRowParallelOp(layer) + if "o_proj" in prefix and oproj_tp_enable(): + return OProjRowParallelOp(layer) + if matmul_allreduce_enable(): + return MatmulAllreduceRowParallelOp(layer) + if enable_sp(): + if "shared_expert" in prefix: + return None + if "o_proj" in prefix or "out_proj" in prefix or "down_proj" in prefix: + return SequenceRowParallelOp(layer) + + return None + + +def get_parallel_op(disable_tp, prefix, layer, direct): + if disable_tp or ("shared_experts" in prefix + and shared_expert_dp_enabled()): + return None, 0, 1 + custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, + MLPRowParallelOp, OProjRowParallelOp, + MatmulAllreduceRowParallelOp, + SequenceRowParallelOp]] = None + if direct == "row": + custom_op = _get_row_parallel_op(prefix, layer) + + if direct == "column": + custom_op = _get_column_parallel_op(prefix, layer) + + if custom_op is not None: + return custom_op, custom_op.tp_rank, custom_op.tp_size + + return None, get_tp_group().rank_in_group, get_tp_group().world_size + + +def get_replicated_op(disable_tp, prefix, + layer) -> Optional[Union[CustomReplicatedOp]]: + if disable_tp: + return None + + return CustomReplicatedOp(layer) diff --git a/vllm_npu/ops/moe/__init__.py b/vllm_npu/ops/moe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/ops/moe/comm_utils.py b/vllm_npu/ops/moe/comm_utils.py new file mode 100644 index 0000000..b8952a9 --- /dev/null +++ b/vllm_npu/ops/moe/comm_utils.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import torch +import torch.distributed +import torch.distributed as dist +import torch_npu + +COMM_STREAM = None + + +def async_all_to_all(input_, + output_split_sizes, + input_split_sizes, + group, + event=None): + if output_split_sizes is None: + # Equal split (all2all) + a2a_out = torch.empty_like(input_) + else: + # Unequal split (all2all-v) + a2a_out = input_.new_empty( + size=[sum(output_split_sizes)] + list(input_.size()[1:]), + dtype=input_.dtype, + device=torch.npu.current_device(), + ) + + if event: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + event.wait() + handle = dist.all_to_all_single( + a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + else: + handle = dist.all_to_all_single(a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + return input_, a2a_out, handle + + +def _gather_along_first_dim(input_, group, output_split_sizes=None): + """Gather tensors and concatenate along the first dimension. + + Args: + input_tensor (torch.Tensor): + A tensor to be gathered. + output_split_sizes (List[int], optional): + A list specifying the sizes of the output splits along the first dimension. + If None, equal splitting is assumed. Default: None. + + Returns: + torch.Tensor: Gathered tensor. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + if output_split_sizes is None: + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + else: + dim_size[0] = sum(output_split_sizes) + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + output_tensor_list = list( + torch.split(output, output_split_sizes, dim=0)) + torch.distributed.all_gather(output_tensor_list, input_, group=group) + + return output + + +def gather_from_sequence_parallel_region( + input_, + group, + output_split_sizes=None, +): + """Wrapper for autograd function: forward: AG, backward: RS """ + return _gather_along_first_dim(input_, group, output_split_sizes) \ No newline at end of file diff --git a/vllm_npu/ops/moe/experts_selector.py b/vllm_npu/ops/moe/experts_selector.py new file mode 100644 index 0000000..0f940f3 --- /dev/null +++ b/vllm_npu/ops/moe/experts_selector.py @@ -0,0 +1,277 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
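A quick, self-contained illustration of the receive-buffer sizing used by async_all_to_all in comm_utils.py above for the unequal-split (all-to-all-v) case (editor's sketch, not part of the patch; sizes are made up and the NPU device placement is omitted for portability):

```python
import torch

input_ = torch.randn(6, 8)            # 6 local tokens, hidden size 8 (hypothetical)
output_split_sizes = [2, 3, 4]        # tokens this rank will receive from ranks 0, 1, 2
a2a_out = input_.new_empty(
    size=[sum(output_split_sizes)] + list(input_.size()[1:]),
    dtype=input_.dtype)
assert a2a_out.shape == (9, 8)        # dist.all_to_all_single would then fill this buffer
```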
+# +from typing import Callable, Optional + +import torch +import torch_npu +from vllm.forward_context import get_forward_context + +from vllm_npu.ascend_config import get_ascend_config + + +def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor=1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None, + global_num_experts: int = -1): + """ + Fused experts with select experts. + + Args: + router_logits: router logits of shape (num_tokens, hidden_size). + hidden_states: Hidden states of shape (num_tokens, hidden_size). + top_k: number of top k experts. + use_grouped_topk: Whether to group experts before selecting top-k. + renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + indices_type: dtype of indices + global_num_experts: Global number of experts. + + Returns: + topk_weights: router weights of shape (num_tokens, top_k). + topk_ids: selected expert IDs of shape (num_tokens, top_k). + """ + # prefetch w1_w3_proj.weight preprocess + weight_prefetch_method = get_forward_context().weight_prefetch_method + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( + hidden_states, "gate_up") + topk_weights, topk_ids = _select_experts_with_fusion_ops( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + global_num_experts=global_num_experts) + + if topk_weights is None: + topk_weights, topk_ids = _native_select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + return topk_weights, topk_ids + + +def _native_grouped_topk( + topk_weights: torch.Tensor, + num_expert_group: Optional[int], + topk_group: Optional[int], +): + topk_group = 0 if topk_group is None else topk_group + num_expert_group = 0 if num_expert_group is None else num_expert_group + + num_token = topk_weights.shape[0] + grouped_weights = topk_weights.view(num_token, num_expert_group, + -1).max(dim=-1).values + topk_group_indices = torch.topk(grouped_weights.to(torch.float32), + k=topk_group, + dim=-1, + sorted=False)[1] + topk_group_mask = torch.zeros_like(grouped_weights) + topk_group_mask.scatter_(1, topk_group_indices, 1) + topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) + topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) + 
+ return topk_weights + + +def _renormalize_topk_weights( + topk_weights: torch.Tensor, + renormalize: bool, +): + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights + + +def _select_expert_use_group_topk( + topk_weights: torch.Tensor, topk_group: Optional[int], + renormalize: bool, top_k: int, num_expert_group: Optional[int], + e_score_correction_bias: Optional[torch.Tensor]): + assert topk_group is not None + assert num_expert_group is not None + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_weights = topk_weights + topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) + + # TODO: Change to npu_group_topk when the latest CANN and NNAL is available + # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) + topk_weights = _native_grouped_topk(topk_weights, num_expert_group, + topk_group) + # TODO bfloat16 is not supported in torch.topk with ge graph. + if e_score_correction_bias is not None: + topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_weights.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False) + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + return topk_weights, topk_ids + + +def _select_experts_with_fusion_ops( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + e_score_correction_bias: Optional[torch.Tensor], + topk_group: Optional[int], + num_expert_group: Optional[int], + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor=1.0, + global_num_experts: int = -1): + + topk_weights, topk_ids = None, None + # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern + global_redundant_expert_num = get_ascend_config().init_redundancy_expert + is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 + if is_deepseek_v3_r1: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax": + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( + x=router_logits, finished=None, k=top_k) + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + + return topk_weights, topk_ids + + +def _native_select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + 
scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Select top-k experts based on router logits. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + router_logits: Router logits of shape (num_tokens, num_experts). + top_k: Number of experts to select. + use_grouped_topk: Whether to group experts before selecting top-k. + renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + + Returns: + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + + Raises: + ValueError: If an unsupported scoring function is provided. + """ + + if scoring_func == "softmax": + topk_weights = router_logits.softmax(dim=-1) + elif scoring_func == "sigmoid": + topk_weights = router_logits.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if use_grouped_topk: + return _select_expert_use_group_topk( + topk_weights=topk_weights, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias) + + if custom_routing_function is not None: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts) + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + return topk_weights, topk_ids + + topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) + topk_weights = topk_weights.to(hidden_states.dtype) + + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + + return topk_weights, topk_ids diff --git a/vllm_npu/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_npu/ops/moe/fused_moe_prepare_and_finalize.py new file mode 100644 index 0000000..b3e0b30 --- /dev/null +++ b/vllm_npu/ops/moe/fused_moe_prepare_and_finalize.py @@ -0,0 +1,520 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
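Recapping the fallback routing path of _native_select_experts in experts_selector.py above (softmax scoring, top-k, optional renormalization, int32 cast) with a tiny, self-contained example; the logits and shapes are made up (editor's sketch, not part of the patch):

```python
import torch

router_logits = torch.tensor([[2.0, 0.5, 1.0, 0.1],
                              [0.2, 0.2, 3.0, 0.4]])   # 2 tokens, 4 experts (hypothetical)
top_k = 2

topk_weights = router_logits.softmax(dim=-1)
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_ids = topk_ids.to(torch.int32)                    # required by npu_moe_init_routing
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize=True

assert topk_ids.tolist() == [[0, 2], [2, 3]]
assert torch.allclose(topk_weights.sum(dim=-1), torch.ones(2))
```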
+ +from abc import ABC, abstractmethod + +import torch +import torch.distributed as dist +import torch.nn as nn +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_dp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_npu.utils import enable_sp + + +class FusedMoEPrepareAndFinalize(ABC): + """ + Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization + in distributed environments. Subclasses implement specific communication strategies + (e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing, + broadcasting, and reduction across TP/DP/EP groups. + + Attributes: + moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info, + sizes, ranks, and communication settings. + """ + + def __init__(self, moe_config: FusedMoEConfig): + self.moe_config = moe_config + + @abstractmethod + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare tensors before MoE computation. May involve: + - Padding to align communication boundaries + - Slicing across tensor-parallel ranks + - Broadcasting across data-parallel ranks + + Args: + hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size] + router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts] + enable_shared_expert_dp (bool): Skip DP communication for shared experts + replace_allreduce (bool): Bypass default all-reduce behavior + + Returns: + Tuple of: + - processed hidden_states (may be padded/sliced/broadcasted) + - processed router_logits (may be recomputed or broadcasted) + - optional communication mask (e.g., mc2_mask for sparse ops) + """ + raise NotImplementedError("Prepare not implemented.") + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalize MoE output. May involve: + - Gathering sliced tensors across TP ranks + - Reducing or scattering across DP ranks + - Unpadding to original token count + - Applying all-reduce across TP/EP if requested + + Args: + hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced + reduce_results (bool): Whether to apply all-reduce across TP/EP groups + + Returns: + torch.Tensor: Final output with shape [original_num_tokens, hidden_size] + """ + raise NotImplementedError("Finalize function not implemented.") + + +class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using MC2 (Memory-Centric Communication). + Designed for Ascend or environments requiring explicit padding and slicing control. + Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment. + """ + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + """ + Restore original TP configuration. + vLLM flattens TP and DP into a single dimension; this method recovers + the true TP world size and rank for correct tensor slicing. 
+ """ + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Fetch `mc2_mask` and target padding length from forward context. + 2. Pad `hidden_states` and `router_logits` to target length if needed. + 3. If TP > 1, split tensors along token dimension and select current TP rank's slice. + 4. Split and return corresponding `mc2_mask`. + + Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded. + """ + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + forward_context = get_forward_context() + mc2_mask = forward_context.mc2_mask + if self.tp_size > 1: + # Also slice mc2_mask + split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0) + mc2_mask = split_mc2_mask[self.tp_rank] + + if not self.replace_allreduce: + self.num_tokens, _ = hidden_states.shape + target_pad_length = forward_context.padded_num_tokens + pad_size = target_pad_length - self.num_tokens + + # Pad if necessary (unless shared expert DP is enabled) + if pad_size > 0 and not self.enable_shared_expert_dp: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + # Slice across TP ranks + if self.tp_size > 1 and not self.enable_shared_expert_dp: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + self.split_hidden_states = split_hidden_states # Save for finalize + + return hidden_states, router_logits, mc2_mask + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor. + 2. Unpad to original token count if padding was applied. + 3. Return tensor with shape [original_num_tokens, hidden_size]. + + Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True. + """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + # All-gather across TP group + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + # TODO: It is a quick bugfix for the memory explosion issue in eager mode. + # If the cache is not cleared after `self.split_hidden_states` is created, + # it can lead to the memory explosion in eager mode. + del self.split_hidden_states + + # Unpad if necessary + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using All-to-All style slicing. + Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. + Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). 
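To make the padding and TP slicing in FusedMoEPrepareAndFinalizeWithMC2.prepare above concrete, here is a small shape-only sketch (editor's example, not part of the patch). Token counts are made up; in the real code `padded_num_tokens` and `mc2_mask` come from the forward context and the gather in finalize uses dist.all_gather:

```python
import torch
import torch.nn.functional as F

tp_size, tp_rank = 2, 0
hidden_states = torch.randn(5, 16)       # 5 tokens, hidden size 16 (hypothetical)
padded_num_tokens = 8                    # target length from the forward context

pad_size = padded_num_tokens - hidden_states.shape[0]
if pad_size > 0:
    hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))   # -> [8, 16]

splits = torch.tensor_split(hidden_states, tp_size, dim=0)      # two [4, 16] slices
local = splits[tp_rank]                                          # this rank's share
assert local.shape == (4, 16)

# finalize() later gathers the slices back and drops the padding again:
restored = torch.cat(splits, dim=0)[:5]
assert restored.shape == (5, 16)
```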
+ """ + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + """Restore original TP configuration (same as MC2).""" + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Pad hidden_states and router_logits to next multiple of TP size. + 2. If TP > 1, split along token dim and select current TP rank's slice. + 3. Save splits for later all-gather in finalize. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, None) — no mask used in All2All. + """ + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + + if not (self.replace_allreduce or self.enable_shared_expert_dp): + self.num_tokens, _ = hidden_states.shape + pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) + + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + self.split_hidden_states = split_hidden_states + + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If TP > 1, all-gather slices to reconstruct full tensor. + 2. Unpad to original token count. + 3. Return [original_num_tokens, hidden_size] tensor. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + # TODO: It is a quick bugfix for the memory explosion issue in eager mode. + # If the cache is not cleared after `self.split_hidden_states` is created, + # it can lead to the memory explosion in eager mode. + del self.split_hidden_states + + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using All-Gather + Reduce-Scatter on EP group. + There are two sets of prepare and finalize: + 1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled, + we gather inputs across DP ranks before MoE, scatter outputs after. + The communication and calculation process is as follows (AG, AR and RS + are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively): + + Attn → TP AR → DP AG → MoE → DP RS → TP AR + + 2. 
_prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled, + the above process becomes: + + TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS + + This strategy further combines TP AG + DP AG into EP All-Gather and TP RS + DP RS + into EP Reduce-Scatter to improve communication performance. The optimized process is as follows: + + TP AG → Attn → TP RS → EP AG → MoE → EP RS + """ + + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + AllGather hidden_states and router_logits to form global tensors. + + Returns: + Tuple of (global_hidden_states, global_router_logits, None) + """ + if enable_sp(): + return self._prepare_with_ep_group(hidden_states, router_logits) + + return self._prepare_with_dp_group(hidden_states, router_logits, + enable_shared_expert_dp, + replace_allreduce) + + def _prepare_with_ep_group( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + hidden_states, True, True) + router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + router_logits, True, True) + + return hidden_states, router_logits, None + + def _prepare_with_dp_group( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Fetch max token count across DP group from forward context. + 2. Pad local tensors to that size. + 3. All-gather across DP group to form global input tensor. + + Returns: + Tuple of (global_hidden_states, global_router_logits, None) + """ + self.enable_shared_expert_dp = enable_shared_expert_dp + if self.moe_config.dp_size > 1: + forward_context = get_forward_context() + max_tokens_across_dp = forward_context.max_tokens_across_dp + + self.num_tokens = hidden_states.shape[0] + pad_size = max_tokens_across_dp - self.num_tokens + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + # All-gather across DP group + hidden_states = self.moe_config.dp_group.all_gather( + hidden_states, 0) + router_logits = self.moe_config.dp_group.all_gather( + router_logits, 0) + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + Reduce Scatter hidden states. + + Returns: + Tensor with shape [local_num_tokens, hidden_size] + """ + if enable_sp(): + return self._finalize_with_ep_group(hidden_states) + + return self._finalize_with_dp_group(hidden_states, reduce_results) + + def _finalize_with_ep_group(self, + hidden_states: torch.Tensor) -> torch.Tensor: + """ + Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled: + 1. Reduce_results is False usually happens when models have shared experts and need to + allreduce hidden states after results of shared experts and routed experts are added in FusedMoe. + We do reduce scatter for hidden states here, then skip allreudce in FusedMoe and add it to the + result of shared experts. + 2 Reduce_results is True usually happens when model has no shared experts. 
We still do reduce scatter + here, then skip allreudce in FusedMoe. + """ + hidden_states = torch.ops.vllm.maybe_pad_and_reduce( + hidden_states, True) + + return hidden_states + + def _finalize_with_dp_group(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If DP > 1 and not shared expert, reduce-scatter output across DP group. + 2. Slice to original local token count. + 3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce. + + Returns: + Tensor with shape [original_local_num_tokens, hidden_size] + """ + if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: + hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) + hidden_states = hidden_states[:self.num_tokens] + + if reduce_results and (self.moe_config.tp_size > 1 + or self.moe_config.ep_size > 1): + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using Naive Multicast (point-to-point broadcast). + Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others. + Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries. + """ + + def _naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + """ + Naive multicast implementation: + 1. Create global buffer sized by total tokens across DP. + 2. Current rank copies its slice into its designated buffer region. + 3. Each rank broadcasts its slice to all others via P2P. + + Args: + x (torch.Tensor): Local tensor [local_tokens, hidden_size] + cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank + + Returns: + torch.Tensor: Global tensor [total_tokens, hidden_size] + """ + assert len(x.shape) == 2, "Input must be 2D [tokens, features]" + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + + # Copy local slice into buffer + start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.moe_config.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank] + buffer[start:end, :].copy_(x) + + # Broadcast each slice to all ranks + for idx in range(self.moe_config.dp_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + get_dp_group().broadcast(buffer[start:end, :], idx) + return buffer + + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Fetch cumulative token boundaries from forward context. + 2. Multicast hidden_states and router_logits to form global tensors. + + Returns: + Tuple of (global_hidden_states, global_router_logits, None) + """ + self.enable_shared_expert_dp = enable_shared_expert_dp + + if self.moe_config.dp_size > 1: + self.cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_sp(1) + hidden_states = self._naive_multicast(hidden_states, + self.cu_tokens_across_dp_cpu) + router_logits = self._naive_multicast(router_logits, + self.cu_tokens_across_dp_cpu) + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. 
If DP > 1 and not shared expert: + - All-reduce across DP + - Slice to current rank's token range using cu_tokens_across_dp_cpu + 2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce. + + Returns: + Tensor with shape [local_num_tokens, hidden_size] + """ + if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: + start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[ + self.moe_config.dp_rank - 1] + end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank] + hidden_states = get_dp_group().all_reduce( + hidden_states) # Sum across DP + hidden_states = hidden_states[start:end, :] + + if reduce_results and (self.moe_config.tp_size > 1 + or self.moe_config.ep_size > 1): + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + return hidden_states diff --git a/vllm_npu/ops/moe/moe_comm_method.py b/vllm_npu/ops/moe/moe_comm_method.py new file mode 100644 index 0000000..ea47b86 --- /dev/null +++ b/vllm_npu/ops/moe/moe_comm_method.py @@ -0,0 +1,273 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +import torch +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_npu.ascend_forward_context import MoECommType +from vllm_npu.ops.moe.fused_moe_prepare_and_finalize import ( + FusedMoEPrepareAndFinalizeWithAll2All, + FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, + FusedMoEPrepareAndFinalizeWithNaiveMulticast) +from vllm_npu.ops.moe.moe_mlp import unified_apply_mlp +from vllm_npu.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, + TokenDispatcherWithMC2, + TokenDispatcherWithMoge) + +_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} + + +def get_moe_comm_method( + moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]: + return _MoECommMethods.get(moe_comm_type, None) + + +def setup_moe_comm_method(moe_config): + _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config) + _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config) + _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config) + _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl( + moe_config) + + +class MoECommMethod(ABC): + """Base class for MoE communication methods.""" + + def __init__(self, moe_config: FusedMoEConfig): + self.model_type = get_current_vllm_config( + ).model_config.hf_config.model_type + self.moe_config = moe_config + self.mc2_mask = None + + self.token_dispatcher = self._get_token_dispatcher() + self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( + ) + + def prepare( + self, + hidden_states: 
torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( + hidden_states, router_logits, enable_shared_expert_dp, + replace_allreduce) + self.mc2_mask = mc2_mask + return hidden_states, router_logits + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + hidden_states = self.fused_moe_prepare_finalize.finalize( + hidden_states, reduce_results) + return hidden_states + + def fused_experts( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + global_num_experts: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + # For TorchAir graph + is_torchair: bool = False, + # For Cube/Vector parallel + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + # For load balance + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + need_trans: bool = False, + dynamic_eplb: bool = False): + # Check constraints + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + moe_comm_method = get_forward_context().moe_comm_method + assert moe_comm_method is not None, "Missing communication context" + + results = self.token_dispatcher.token_dispatch( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + mc2_mask=self.mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + with_quant=use_int8_w8a8 or use_int4_w4a8, + dynamic_eplb=dynamic_eplb) + + permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ + results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") + + mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=expert_tokens, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + topk_scales=topk_scales, + with_quant=use_int8_w8a8 + or use_int4_w4a8, + fusion=use_int8_w8a8, + need_trans=need_trans, + dynamic_eplb=dynamic_eplb) + + final_hidden_states = self.token_dispatcher.token_combine( + hidden_states=mlp_output) + + if dynamic_eplb: + return (final_hidden_states, group_list_type, expert_tokens) + + return final_hidden_states + + @abstractmethod + def _get_token_dispatcher(self): + raise NotImplementedError( + "_get_token_dispatcher function not implemented.") + + @abstractmethod + def _get_fused_moe_prepare_finalize(self): + raise NotImplementedError( + "_get_fused_moe_prepare_finalize function not implemented.") + + +class AllGatherCommImpl(MoECommMethod): + """This implementation is the same as NativeAllGatherCommImpl, 
+ but uses NPU-specific ops for better performance. + + This implementation should be compatible with all scenarios, and + thus it is the default implementation for MoE communication methods. + It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing + and `torch_npu.npu_moe_token_unpermute` for post-processing + to handle the token-to-expert mapping and communication efficiently. + + NOTE(Yizhou): TBH, it is really weird that we were supposed to use + `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` + or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` + for pre-processing and post-processing, respectively. + But `npu_moe_finalize_routing` will lead to accuracy issues so we have to + use `torch_npu.npu_moe_token_unpermute` instead. + This is a workaround and should be removed after the issue is fixed. + """ + + def _get_token_dispatcher(self): + if self.model_type == "PanguProMoE": + return TokenDispatcherWithMoge( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + else: + return TokenDispatcherWithAllGather( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) + + +class MC2CommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. + 3. `enable_expert_parallel=False` is not supported. + + This implementation uses the MC2 communication method, which is optimized for + Communication and Computation parallelism on Ascend devices. + """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithMC2() + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) + + +class AlltoAllCommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_grouped_matmul` is available. + + This implementation uses all-to-all communication to exchange tokens + between data parallel ranks before and after the MLP computation. It should + have better performance than AllGatherCommImpl when DP size > 1. + """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithAll2AllV( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + + +class NaiveMulticastCommImpl(MoECommMethod): + """This implementation is the same as NativeAllGatherCommImpl, + but uses NPU-specific ops for better performance. + + This implementation should be compatible with all scenarios, and + thus it is the default implementation for MoE communication methods. + It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing + and `torch_npu.npu_moe_token_unpermute` for post-processing + to handle the token-to-expert mapping and communication efficiently. + + NOTE(Yizhou): TBH, it is really weird that we were supposed to use + `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` + or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` + for pre-processing and post-processing, respectively. 
+ But `npu_moe_finalize_routing` will lead to accuracy issues so we have to + use `torch_npu.npu_moe_token_unpermute` instead. + This is a workaround and should be removed after the issue is fixed. + """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithAllGather( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config) diff --git a/vllm_npu/ops/moe/moe_mlp.py b/vllm_npu/ops/moe/moe_mlp.py new file mode 100644 index 0000000..4e35f35 --- /dev/null +++ b/vllm_npu/ops/moe/moe_mlp.py @@ -0,0 +1,266 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from typing import Optional + +import torch +import torch_npu +from torch.nn.functional import pad +from vllm.forward_context import get_forward_context + +from vllm_npu.ascend_forward_context import MoECommType +from vllm_npu.utils import dispose_tensor, is_310p + + +def cumsum_group_list(group_list: torch.Tensor, + src_list_type: int, + dst_list_type: int, + active_num: int = 0, + expert_num: int = 0) -> torch.Tensor: + if src_list_type not in [0, 1, 2]: + raise ValueError( + f"group_list_type should be in [0, 1, 2], but received {src_list_type}" + ) + + if src_list_type == dst_list_type: + return group_list + if src_list_type == 1 and dst_list_type == 0: + return group_list.cumsum(dim=0) + if src_list_type == 0 and dst_list_type == 1: + group_diff = torch.diff(group_list) + new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0) + return new_group + if src_list_type == 2 and dst_list_type == 0: + experts = pad(group_list[:, 0], (1, 0)) + tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0)) + cumsum_group_list = torch.full(size=(expert_num, ), + fill_value=active_num, + dtype=group_list.dtype, + device=group_list.device) + + for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])): + if end > start: + cumsum_group_list[start:end] = tokens[i] + + return cumsum_group_list + raise NotImplementedError( + f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. 
" + "This feature is under development.") + + +def quant_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, + dynamic_scale: torch.Tensor = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + fusion: bool = False, + dynamic_eplb: bool = False) -> torch.Tensor: + if dynamic_scale is None: + unquantized_hidden_states = hidden_states + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used. + dispose_tensor(unquantized_hidden_states) + else: + pertoken_scale = dynamic_scale + + bias1, bias2 = None, None + _output_dtype = w2_scale.dtype + + weight_prefetch_method = get_forward_context().weight_prefetch_method + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_moe_weight_postprocess( + hidden_states) + is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 + if w1_scale_bias is None and is_mc2: + if fusion and not dynamic_eplb: + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + x=hidden_states, + weight=w1, + group_list=cumsum_group_list(group_list, group_list_type, 0), + weight_scale=w1_scale, + x_scale=pertoken_scale) + else: + if w1_scale.dtype != torch.float32: + w1_scale = w1_scale.to(torch.float32) + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=cumsum_group_list(group_list, group_list_type, 1), + activate_left=True, + quant_mode=1, + ) + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + else: + if w1_scale_bias is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], + torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] if not fusion else w1_scale_bias + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + + if fusion and not dynamic_eplb: + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + x=hidden_states, + weight=w1, + bias=bias1, + group_list=cumsum_group_list(group_list, group_list_type, 0), + weight_scale=w1_scale, + x_scale=pertoken_scale) + else: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + scale=[w1_scale.to(w2_scale.dtype)], + bias=bias1, + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # gmm2: down_proj + hidden_states = 
torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + bias=bias2, + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + return hidden_states + + +def unquant_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, + topk_scales: Optional[torch.Tensor] = None, + need_trans: bool = True) -> torch.Tensor: + + if need_trans: + w1 = w1.transpose(1, 2) + w2 = w2.transpose(1, 2) + + gate_up_out = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + if is_310p(): + gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( + torch.float16) + else: + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + + if topk_scales is not None: + gate_up_out *= topk_scales + + hidden_states = torch_npu.npu_grouped_matmul( + x=[gate_up_out], + weight=[w2], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + return hidden_states + + +def unified_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + topk_scales: Optional[torch.Tensor] = None, + with_quant: bool = False, + fusion: bool = False, + need_trans: bool = True, + dynamic_eplb: bool = False) -> torch.Tensor: + if with_quant: + return quant_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + fusion=fusion, + dynamic_eplb=dynamic_eplb) + else: + return unquant_apply_mlp(hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + group_list_type=group_list_type, + topk_scales=topk_scales, + need_trans=need_trans) \ No newline at end of file diff --git a/vllm_npu/ops/moe/token_dispatcher.py b/vllm_npu/ops/moe/token_dispatcher.py new file mode 100644 index 0000000..dac687c --- /dev/null +++ b/vllm_npu/ops/moe/token_dispatcher.py @@ -0,0 +1,725 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +import torch_npu +from vllm.distributed.parallel_state import get_ep_group + +from vllm_npu.distributed.parallel_state import get_mc2_group +from vllm_npu.ops.moe.comm_utils import ( + async_all_to_all, gather_from_sequence_parallel_region) +from vllm_npu.utils import (AscendSocVersion, get_ascend_soc_version, + is_hierarchical_communication_enabled) + + +class MoETokenDispatcher(ABC): + + def __init__(self, **kwargs) -> None: + """ + Initialize the MoE Token Dispatcher. + """ + self.top_k = kwargs.get("top_k", 0) + self.num_experts = kwargs.get("num_experts", 0) + + @property + def ep_group(self): + """Get expert model parallel group.""" + return get_ep_group().device_group + + @property + def ep_rank(self): + return get_ep_group().rank_in_group + + @property + def ep_size(self): + return get_ep_group().world_size + + @abstractmethod + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False): + raise NotImplementedError("Dispatch function not implemented.") + + @abstractmethod + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + raise NotImplementedError("Combine function not implemented.") + + +class TokenDispatcherWithMC2(MoETokenDispatcher): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) + self.ep_rank_id = get_mc2_group().rank_in_group + self.ep_world_size = get_mc2_group().world_size + self.enable_dispatch_v2 = hasattr(torch_npu, + "npu_moe_distribute_dispatch_v2") + self.need_extra_args = ( + get_ascend_soc_version() == AscendSocVersion.A3) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + self.a3_need_extra_args = \ + get_ascend_soc_version() == AscendSocVersion.A3 + # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and + # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly + # improve communication performance. 
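# --- Editorial sketch (illustrative only, not part of this patch) ---
# The NOTE above recommends HCCL_INTRA_PCIE_ENABLE=1 and HCCL_INTRA_ROCE_ENABLE=0 on
# A2 hardware. One assumed way to apply such settings is to export them in the launch
# environment before any HCCL communicator is created; whether they help depends on
# the deployment, so treat the values as a starting point to validate rather than a
# requirement.
import os

os.environ.setdefault("HCCL_INTRA_PCIE_ENABLE", "1")
os.environ.setdefault("HCCL_INTRA_ROCE_ENABLE", "0")
# ...then launch the vLLM engine / distributed workers as usual.
# --- end sketch ---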
+ self.need_expert_scale = is_hierarchical_communication_enabled() + self.output = None + self.assist_info_for_combine = None + self.ep_recv_counts = None + self.shared_act = None + self.topk_ids = None + self.topk_weights = None + self.shared_experts = None + self.mc2_mask = None + self.with_quant = False + self.expand_scales = None + + def get_dispatch_mc2_kwargs( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor, + global_redundant_expert_num: int = 0, + ): + quant_mode = 2 if self.with_quant else 0 + self.moe_expert_num = len(expert_map) + global_redundant_expert_num + kwargs_mc2 = { + "x": hidden_states, + "expert_ids": topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": self.moe_expert_num, + "global_bs": 0, + "expert_token_nums_type": 0, + } + + stage1_kwargs = { + "scales": None, + "quant_mode": quant_mode, + "group_ep": self.moe_all_to_all_group_name, + "ep_world_size": self.ep_world_size, + "ep_rank_id": self.ep_rank_id, + } + if self.need_extra_args: + stage1_kwargs.update({ + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if self.a3_need_extra_args and self.enable_dispatch_v2: + stage1_kwargs.update({ + "x_active_mask": self.mc2_mask, + }) + if self.need_expert_scale: + stage1_kwargs.update({ + "expert_scales": + topk_weights.to(torch.float32), + }) + + kwargs_mc2.update(stage1_kwargs) + return kwargs_mc2 + + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False): + # Apply log2phy if needed + if log2phy is not None: + topk_ids = log2phy[topk_ids] + + self.with_quant = with_quant + self.expert_map = expert_map + self.topk_ids = topk_ids + self.topk_weights = topk_weights + self.shared_experts = shared_experts + self.mc2_mask = mc2_mask + + kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, + topk_ids, expert_map, + global_redundant_expert_num) + self.output = torch_npu.npu_moe_distribute_dispatch_v2( + **kwargs_mc2 + ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( + **kwargs_mc2) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \ + self.ep_recv_counts, _, self.expand_scales = self.output[0:7] + + if self.with_quant: + if shared_experts is not None: + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + + shared_act_out = shared_experts.act_fn( + (shared_gate_up, shared_dequant_scale)) + self.shared_act, self.swiglu_out_scale = \ + shared_act_out[0], shared_act_out[1] + + else: + if shared_experts is not None: + shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) + self.shared_act = shared_experts.act_fn(shared_gate_up) + group_list_type = 0 + return { + "group_list_type": group_list_type, + "hidden_states": expand_x, + "group_list": expert_token_nums, + "dynamic_scale": dynamic_scale, + } + + def 
get_combine_mc_kwargs(self, hidden_states: torch.Tensor): + assert self.expert_map is not None + assert self.topk_weights is not None + assert self.topk_ids is not None + assert self.output is not None + # moeCombine + kwargs_mc2 = { + "expand_x": hidden_states, + "expert_ids": self.topk_ids, + "expert_scales": self.topk_weights.to(torch.float32), + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": self.moe_expert_num, + "global_bs": 0, + } + if self.with_quant: + tp_recv_counts = torch.empty(1, + dtype=torch.int32, + device=hidden_states.device) + else: + tp_recv_counts = self.output[5] + stage3_kwargs = { + "ep_send_counts": self.ep_recv_counts, + "group_ep": self.moe_all_to_all_group_name, + "ep_world_size": self.ep_world_size, + "ep_rank_id": self.ep_rank_id, + "expand_scales": self.expand_scales, + } + if self.enable_dispatch_v2: + stage3_kwargs.update({ + "assist_info_for_combine": + self.assist_info_for_combine, + }) + else: + stage3_kwargs.update({ + "expand_idx": self.assist_info_for_combine, + }) + if self.need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if self.a3_need_extra_args and self.enable_dispatch_v2: + stage3_kwargs.update({ + "x_active_mask": self.mc2_mask, + }) + kwargs_mc2.update(stage3_kwargs) + return kwargs_mc2 + + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states) + hidden_states = torch_npu.npu_moe_distribute_combine_v2( + **kwargs_mc2 + ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( + **kwargs_mc2) + + # these values are no longer used, so they need to be set to None for memory release. 
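# --- Editorial sketch (illustrative only, not part of this patch) ---
# The attribute resets just below follow a pattern used throughout this file: once a
# cached tensor is no longer needed, drop the reference (and, where the code does so
# explicitly, shrink its storage) so the allocator can reuse the memory before the
# next layer runs. A minimal CPU-side illustration of the same idea:
import torch

buf = torch.empty(1024, 1024)      # large intermediate tensor
total = buf.sum()                  # last real use of `buf`
buf.untyped_storage().resize_(0)   # eagerly release the backing storage
buf = None                         # and drop the reference itself
# --- end sketch ---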
+ self.output = None + self.assist_info_for_combine = None + self.ep_recv_counts = None + self.topk_ids = None + self.topk_weights = None + self.mc2_mask = None + self.expert_map = None + self.expand_scales = None + + if self.shared_experts is None: + return hidden_states + else: + if self.with_quant: + shared_hidden_states, _ = self.shared_experts.down_proj( + (self.shared_act, self.swiglu_out_scale)) + else: + shared_hidden_states, _ = self.shared_experts.down_proj( + self.shared_act) + self.shared_act = None + self.shared_experts = None + self.swiglu_out_scale = None + return hidden_states, shared_hidden_states + + +class TokenDispatcherWithAllGather(MoETokenDispatcher): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.apply_router_weight_on_input = False + self.max_num_tokens = kwargs.get("max_num_tokens") + self.num_experts_local = kwargs.get("num_local_experts", 0) + self.sorted_weights = None + self.expanded_row_idx = None + self.sorted_token_indices = None + self.original_shape = None + self.mask = None + self.expert_map = None + self.topk_weights = None + self.topk_ids = None + self.with_quant = False + + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False): + self.with_quant = with_quant + self.original_shape = hidden_states.shape + + num_tokens = hidden_states.shape[:-1].numel() + self.expert_map = expert_map + self.topk_weights = topk_weights + self.topk_ids = topk_ids + self.apply_router_weight_on_input = apply_router_weight_on_input + if self.apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * \ + topk_weights.to(hidden_states.dtype) + if expert_map is not None: + global_num_experts = len(expert_map) + global_redundant_expert_num + mask = (expert_map[topk_ids] != -1) + self.topk_weights = topk_weights * mask + first_expert_idx = get_ep_group( + ).rank_in_group * self.num_experts_local + last_expert_idx = first_expert_idx + self.num_experts_local + else: + first_expert_idx = 0 + last_expert_idx = self.num_experts_local + global_num_experts = self.num_experts_local + + sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = ( + torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + active_num=num_tokens * self.top_k, + expert_num=global_num_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[first_expert_idx, last_expert_idx], + quant_mode=1 if self.with_quant else -1, + )) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 1 # `count` mode + return { + "group_list_type": group_list_type, + "hidden_states": sorted_hidden_states, + "group_list": expert_tokens, + "dynamic_scale": pertoken_scale if self.with_quant else None, + } + + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + assert self.original_shape is not None + final_hidden_states = 
torch_npu.npu_moe_token_unpermute( + permuted_tokens=hidden_states, + sorted_indices=torch.abs(self.expanded_row_idx), + probs=self.topk_weights) + if len(self.original_shape) == 3: + final_hidden_states = final_hidden_states.view(self.original_shape) + + # these values are no longer used, so they need to be set to None for memory release. + self.expert_map = None + self.topk_weights = None + self.topk_ids = None + self.expanded_row_idx = None + return final_hidden_states + + +# mypy: disable-error-code="override" +class TokenDispatcherWithMoge(MoETokenDispatcher): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.apply_router_weight_on_input = False + self.local_num_experts = self.num_experts // self.ep_size + self.local_num_group = self.top_k // self.ep_size + self.bsz = None + + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False): + self.bsz, _ = hidden_states.shape + flatten_topk_ids = topk_ids.view(-1) + self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) + self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32) + sorted_hidden_states = hidden_states.index_select( + 0, self.sorted_topk_ids // self.local_num_group) + + experts_id = torch.arange(0, + self.local_num_experts, + dtype=topk_ids.dtype, + device=topk_ids.device) + num_tokens_per_expert = ( + flatten_topk_ids.unsqueeze(-1) == experts_id).to( + torch.float32).sum(0) + topk_scales = topk_weights.view(-1).index_select( + 0, self.sorted_topk_ids).unsqueeze(-1) + group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) + group_list_type = 0 + return { + "group_list_type": group_list_type, + "hidden_states": sorted_hidden_states, + "group_list": group_list, + "topk_scales": topk_scales, + } + + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to( + torch.int32) + unsorted_hidden_states = hidden_states.index_select( + 0, unsorted_topk_ids) + final_hidden_states = unsorted_hidden_states.reshape( + self.bsz, self.top_k // self.ep_size, -1).sum(1) + return final_hidden_states + + +class TokenDispatcherWithAll2AllV(MoETokenDispatcher): + """ + The implementation of the AlltoAll-based token dispatcher, which handles token + dispatching on the sequence level instead of token level. The core of this implementation + lies in each device dispatching on the entire sequence, with the hidden state being partitioned. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.with_quant = False + self.num_local_experts = kwargs.get("num_local_experts", 0) + + self.hidden_shape = None + self.topk_weights = None + self.input_splits = None + self.output_splits = None + self.hidden_shape_before_permute = None + + # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert = None + + # cached intermediate tensors. 
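# --- Editorial sketch (illustrative only, not part of this patch) ---
# The `input_splits` / `output_splits` cached above are filled in later by
# `_preprocess`. A toy example of how they are derived, assuming ep_size=2 and
# num_local_experts=2 (4 experts total), with this rank owning experts 0 and 1:
import torch

ep_size, num_local_experts = 2, 2

# tokens this rank routes to each global expert: e0, e1, e2, e3
num_local_tokens_per_expert = torch.tensor([3, 1, 2, 0])
input_splits = num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(dim=1)
print(input_splits.tolist())   # [4, 2] -> send 4 tokens to rank 0, 2 tokens to rank 1

# per-expert counts gathered from every rank (rows = sending rank, cols = global expert)
num_global_tokens_per_expert = torch.tensor([[3, 1, 2, 0],
                                             [1, 2, 0, 4]])
output_splits = num_global_tokens_per_expert[:, :num_local_experts].sum(dim=-1)
print(output_splits.tolist())  # [4, 3] -> receive 4 tokens from rank 0, 3 from rank 1
# --- end sketch ---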
+ self.tokens_per_expert = None + self.global_input_tokens_local_experts_indices = None + + assert self.num_local_experts > 0, "Expected at least one expert" + if self.num_local_experts > 1: + self.expert_ids_per_ep_rank = torch.tensor( + [i % self.num_local_experts for i in range(self.num_experts)], + dtype=torch.int32, + device=torch.npu.current_device(), + ) + + local_expert_indices_offset = (self.ep_rank * self.num_local_experts) + + self.local_expert_indices = [ + local_expert_indices_offset + i + for i in range(self.num_local_experts) + ] + assert (len(self.local_expert_indices) == self.num_local_experts + ), "Invalid local expert indices" + for i in range(len(self.local_expert_indices) - 1): + assert (self.local_expert_indices[i] == + self.local_expert_indices[i + 1] - + 1), "local_expert_indices must be continuous" + + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False): + self.with_quant = with_quant + self.hidden_shape = hidden_states.shape + self.topk_weights = topk_weights + assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights" + assert topk_ids.dim() == 2, "Expected 2D tensor for routing map" + + if log2phy is not None: + topk_ids = log2phy[topk_ids] + + permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess( + hidden_states, topk_ids) + self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping + + dynamic_scale_after_all2all = None + if self.with_quant: + permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant( + permutated_local_input_tokens) + + _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( + dynamic_scale, + self.output_splits, + self.input_splits, + self.ep_group, + ) + permute2_ep_all_to_all_handle.wait() + dynamic_scale.untyped_storage().resize_(0) + + _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( + permutated_local_input_tokens, + self.output_splits, + self.input_splits, + self.ep_group, + ) + permute1_ep_all_to_all_handle.wait() + permutated_local_input_tokens.untyped_storage().resize_(0) + + global_input_tokens, dynamic_scale = self._dispatch_postprocess( + global_input_tokens, dynamic_scale_after_all2all) + return { + "hidden_states": global_input_tokens, + "group_list": tokens_per_expert, + "dynamic_scale": dynamic_scale, + "group_list_type": 1 + } + + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." 
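# --- Editorial sketch (illustrative only, not part of this patch) ---
# The dispatchers in this file hand `group_list` to the MLP in one of two layouts:
# group_list_type=1 carries per-expert token counts, while group_list_type=0 carries
# the running cumulative sum (`cumsum_group_list` in moe_mlp.py converts between
# them). A toy illustration of the two layouts for the same routing:
import torch

counts = torch.tensor([3, 0, 2, 5])    # group_list_type == 1 (per-expert counts)
cumulative = counts.cumsum(dim=0)      # group_list_type == 0 -> [3, 3, 5, 10]
back_to_counts = torch.cat([cumulative[:1], torch.diff(cumulative)])
assert torch.equal(back_to_counts, counts)
# --- end sketch ---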
+ + hidden_states = self._combine_preprocess(hidden_states) + + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + _, permutated_local_input_tokens, handle = async_all_to_all( + hidden_states, self.input_splits, self.output_splits, + self.ep_group) + handle.wait() + hidden_states.untyped_storage().resize_(0) + + output = self._combine_postprocess(permutated_local_input_tokens) + + # these values are no longer used, so they need to be set to None for memory release. + self.input_splits = None + self.output_splits = None + self.num_global_tokens_per_local_expert = None + self.topk_weights = None + self.reversed_local_input_permutation_mapping = None + self.reversed_global_input_permutation_mapping = None + self.global_input_tokens_local_experts_indices = None + + return output + + def _dispatch_preprocess(self, hidden_states, topk_ids): + assert self.hidden_shape is not None + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self._preprocess(topk_ids) + + self.hidden_shape_before_permute = hidden_states.shape + + permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=topk_ids, + num_out_tokens=self.num_out_tokens, + ) + return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert + + def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor: + num_local_tokens_per_expert = torch.histc(topk_ids, + bins=self.num_experts, + min=0, + max=self.num_experts) + + ep_size = self.ep_size + + # Dropless + self.num_out_tokens = topk_ids.numel() + + # =================================================== + # Calculate input_splits, output_splits for alltoall-v. + # =================================================== + self.input_splits = (num_local_tokens_per_expert.reshape( + ep_size, + self.num_local_experts).sum(axis=1).to(torch.device("cpu"), + non_blocking=True).numpy()) + num_global_tokens_per_expert = gather_from_sequence_parallel_region( + num_local_tokens_per_expert, + group=self.ep_group).reshape(ep_size, self.num_experts) + self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ + 0]:self.local_expert_indices[-1] + 1] + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before sum.") + self.output_splits = (self.num_global_tokens_per_local_expert.sum( + axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) + num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( + axis=0) + # =================================================== + # num_global_tokens_per_expert: [ep_size, num_experts] + # num_global_tokens_per_local_expert: [ep_size, num_local_experts] + # num_tokens_per_local_expert: [num_local_experts] + # =================================================== + + if self.num_local_experts > 1: + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before operations." + ) + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.expert_ids_per_ep_rank, + self.num_global_tokens_per_local_expert.ravel()) + else: + # TODO: This full synchronization can be a performance bottleneck. + # A more granular sync (e.g., blocking D2H copies) should be investigated. 
+ torch.npu.synchronize() + + return num_tokens_per_local_expert + + def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None): + # Early return if no local experts or no tokens + if self.num_local_experts <= 1: + return global_input_tokens, None + + # Handle quantized case + if self.with_quant: + assert self.global_input_tokens_local_experts_indices is not None, \ + "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess" + expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze( + -1) + active_num = self.global_input_tokens_local_experts_indices.numel() + + # Handle case with no active tokens + if active_num <= 0: + self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices + return global_input_tokens, dynamic_scale + + # Process with active tokens + global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2( + global_input_tokens, + expert_idx_2d, + scale=dynamic_scale, + active_num=active_num, + expert_capacity=0, + expert_num=self.num_local_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[0, self.num_local_experts], + quant_mode=-1, + row_idx_type=0) + return global_input_tokens, expanded_scale + + # Handle non-quantized case + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + global_input_tokens, + self.global_input_tokens_local_experts_indices) + return global_input_tokens, None + + def _combine_preprocess(self, hidden_states): + # Unpermutation 2: expert output to AlltoAll input + if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + hidden_states = torch_npu.npu_moe_token_unpermute( + hidden_states, self.reversed_global_input_permutation_mapping) + + return hidden_states + + def _combine_postprocess(self, permutated_local_input_tokens): + # Unpermutation 1: AlltoAll output to output + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=permutated_local_input_tokens, + sorted_indices=self.reversed_local_input_permutation_mapping.to( + torch.int32), + probs=self.topk_weights, + restore_shape=self.hidden_shape_before_permute) + + # Reshape the output tensor + output = output.view(self.hidden_shape) + return output diff --git a/vllm_npu/ops/register_custom_ops.py b/vllm_npu/ops/register_custom_ops.py new file mode 100644 index 0000000..c02f03e --- /dev/null +++ b/vllm_npu/ops/register_custom_ops.py @@ -0,0 +1,315 @@ +import torch +import torch.nn.functional as F +import torch_npu +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import get_forward_context +from vllm.utils import direct_register_custom_op + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_forward_context import MoECommType +from vllm_npu.ops.weight_prefetch import maybe_npu_prefetch +from vllm_npu.utils import npu_stream_switch, prefetch_stream + + +def _maybe_all_gather_and_maybe_unpad_impl( + x: torch.Tensor, + label: bool, + is_ep_comm: bool = False) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + return x + + sp_enabled = forward_context.sp_enabled + if sp_enabled and label: + dp_metadata = forward_context.dp_metadata + if dp_metadata is None or not is_ep_comm: + x = 
tensor_model_parallel_all_gather(x, 0) + pad_size = forward_context.pad_size + if pad_size > 0: + x = x[:-pad_size, :] + else: + x = get_ep_group().all_gather(x, 0) + # unpad + num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu + result = torch.empty( + (num_tokens_across_dp_cpu.sum(), *x.shape[1:]), + device=x.device, + dtype=x.dtype) + dp_size = get_dp_group().world_size + x = x.view(dp_size, forward_context.padded_length, *x.shape[1:]) + offset = 0 + for idx in range(dp_size): + num_tokens_dp = num_tokens_across_dp_cpu[idx] + result[offset:offset + + num_tokens_dp, :] = x[idx, :num_tokens_dp, :] + offset += num_tokens_dp + x = result + + return x + + +def _maybe_pad_and_reduce_impl(x: torch.Tensor, + is_ep_comm: bool = False) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + return tensor_model_parallel_all_reduce(x) + + if not forward_context.sp_enabled: + return tensor_model_parallel_all_reduce(x) + + dp_metadata = forward_context.dp_metadata + if dp_metadata is None or not is_ep_comm: + pad_size = forward_context.pad_size + if pad_size > 0: + x = F.pad(x, (0, 0, 0, pad_size)) + return tensor_model_parallel_reduce_scatter(x, 0) + else: + # padding + dp_size = get_dp_group().world_size + num_tokens_across_dp_cpu = \ + get_forward_context().dp_metadata.num_tokens_across_dp_cpu + padded_x = torch.empty( + (dp_size, forward_context.padded_length, *x.shape[1:]), + device=x.device, + dtype=x.dtype) + offset = 0 + for idx in range(dp_size): + num_tokens_dp = num_tokens_across_dp_cpu[idx] + padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp] + offset += num_tokens_dp + + return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]), + 0) + + +def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, + prefix: str) -> None: + try: + forward_context = get_forward_context() + except AssertionError: + return + + if not forward_context.prefetch_mlp_enabled: + return + model_instance = forward_context.model_instance + prefetch_stream = forward_context.prefetch_stream + layer_idx = int(prefix.split('.')[2]) + + # start point of gate_up_proj weight prefetch + if prefix.split('.')[-2] == "self_attn": + forward_context.prefetch_mlp_gate_up_proj = True + if forward_context.prefetch_mlp_gate_up_proj: + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + mlp_gate_up_prefetch_size = envs_ascend.vllm_npu_MLP_GATE_UP_PREFETCH_SIZE + torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \ + x_dependency, mlp_gate_up_prefetch_size) + return + + +def _maybe_all_gather_and_maybe_unpad_fake( + x: torch.Tensor, + label: bool, + is_ep_comm: bool = False) -> torch.Tensor: + + if get_forward_context().sp_enabled and label: + return torch.empty( + (x.shape[0] * get_tensor_model_parallel_world_size(), + *x.shape[1:]), + device=x.device, + dtype=x.dtype) + + return x + + +def _maybe_pad_and_reduce_fake(x: torch.Tensor, + is_ep_comm: bool = False) -> torch.Tensor: + if get_forward_context().sp_enabled: + return torch.empty( + (x.shape[0] // get_tensor_model_parallel_world_size(), + *x.shape[1:]), + device=x.device, + dtype=x.dtype) + + return x + + +def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, + prefix: str) -> None: + return + + +def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: + try: + forward_context = get_forward_context() + except AssertionError: + return + + if not 
forward_context.prefetch_mlp_enabled: + return + forward_context.prefetch_mlp_down_proj = True + model_instance = forward_context.model_instance + prefetch_stream = forward_context.prefetch_stream + layer_idx = forward_context.layer_idx + + # start point of down_proj weight prefetch + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + mlp_down_prefetch_size = envs_ascend.vllm_npu_MLP_DOWN_PREFETCH_SIZE + torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \ + x_dependency, mlp_down_prefetch_size) + forward_context.layer_idx += 1 + return + + +def _maybe_prefetch_mlp_down_proj_impl_fake( + x_dependency: torch.Tensor) -> None: + return + + +def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None: + try: + forward_context = get_forward_context() + except AssertionError: + return + + if not forward_context.prefetch_mlp_enabled: + return + if forward_context.prefetch_mlp_gate_up_proj or \ + forward_context.prefetch_mlp_down_proj: + prefetch_stream = forward_context.prefetch_stream + # wait until prefetch done + torch.npu.current_stream().wait_stream(prefetch_stream) + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + return + + +def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None: + return + + +def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor, + max_weight_size: int) -> None: + calculation_stream = torch_npu.npu.current_stream() + weight_prefetch_stream = prefetch_stream() + weight_prefetch_stream.wait_stream(calculation_stream) + with npu_stream_switch(weight_prefetch_stream): + maybe_npu_prefetch(inputs=weight, + dependency=start_flag, + max_size=max_weight_size) + + +def _prefetch_preprocess_impl_fake(weight: torch.Tensor, + start_flag: torch.Tensor, + max_weight_size: int) -> None: + return + + +def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None: + calculation_stream = torch_npu.npu.current_stream() + weight_prefetch_stream = prefetch_stream() + calculation_stream.wait_stream(weight_prefetch_stream) + + +def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None: + return + + +def _maybe_all_reduce_tensor_model_parallel_impl( + final_hidden_states: torch.Tensor) -> torch.Tensor: + forward_context = get_forward_context() + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2 + } or forward_context.sp_enabled: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + + +def _matmul_and_reduce_impl(input_parallel: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.custom_op is not None + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output = self.custom_op.matmul_and_reduce(input_parallel, bias_) + + return output + + +def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + num_tokens = input_parallel.size(0) + if forward_context.sp_enabled: + num_tokens = num_tokens // self.tp_size + output = torch.empty(size=(num_tokens, self.output_size_per_partition), + device=input_parallel.device, + dtype=input_parallel.dtype) + + return output + + +direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", + 
op_func=_maybe_all_gather_and_maybe_unpad_impl, + fake_impl=_maybe_all_gather_and_maybe_unpad_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_pad_and_reduce", + op_func=_maybe_pad_and_reduce_impl, + fake_impl=_maybe_pad_and_reduce_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj", + op_func=_maybe_prefetch_mlp_gate_up_proj_impl, + fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj", + op_func=_maybe_prefetch_mlp_down_proj_impl, + fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_wait_prefetch_done", + op_func=_maybe_wait_prefetch_done_impl, + fake_impl=_maybe_wait_prefetch_done_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="prefetch_preprocess", + op_func=_prefetch_preprocess_impl, + fake_impl=_prefetch_preprocess_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="prefetch_postprocess", + op_func=_prefetch_postprocess_impl, + fake_impl=_prefetch_postprocess_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", + op_func=_maybe_all_reduce_tensor_model_parallel_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="matmul_and_reduce", + op_func=_matmul_and_reduce_impl, + fake_impl=_matmul_and_reduce_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_npu/ops/rotary_embedding.py b/vllm_npu/ops/rotary_embedding.py index b5797f3..43520cf 100644 --- a/vllm_npu/ops/rotary_embedding.py +++ b/vllm_npu/ops/rotary_embedding.py @@ -1,22 +1,129 @@ -""" -NPU-optimized rotary embedding for Ascend. - -Provides ``AscendRotaryEmbedding`` — a proper ``RotaryEmbedding`` subclass -with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route -to the NPU fused kernel automatically. -""" +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
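# --- Editorial sketch (illustrative only, not part of this patch) ---
# Before moving on to the rotary-embedding changes below: the ops registered above in
# register_custom_ops.py become callable as `torch.ops.vllm.<op_name>` (see
# `torch.ops.vllm.maybe_pad_and_reduce` in fused_moe_prepare_and_finalize.py), and the
# `fake_impl` only describes output shapes/dtypes so the op can be traced without
# running NPU collectives. A toy registration following the same pattern (the op name
# and behaviour here are invented for illustration):
import torch
from vllm.utils import direct_register_custom_op


def _double_impl(x: torch.Tensor) -> torch.Tensor:
    return x * 2                   # real (device-side) behaviour


def _double_fake(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)     # shape/dtype only, used during tracing


direct_register_custom_op(op_name="editorial_double_example",
                          op_func=_double_impl,
                          fake_impl=_double_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

# call sites would then use: y = torch.ops.vllm.editorial_double_example(x)
# --- end sketch ---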
+# +import math from typing import Optional, Tuple import torch -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +import torch_npu +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, + YaRNScalingRotaryEmbedding) + +from vllm_npu.platform import NPUPlatform +from vllm_npu.utils import enable_custom_op, is_310p + + +def _custom_rotary_embedding_enabled(query, neox_style, head_size): + return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( + ) + + +def _rope_forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + is_neox_style: bool, + offsets: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + query_shape, key_shape = query.shape, key.shape + if self.cos_sin_cache.device != query.device: + self.cos_sin_cache = self.cos_sin_cache.to(query.device) + if self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + # adopt custom kernel path for rotary_embedding + if _custom_rotary_embedding_enabled(query, is_neox_style, + self.head_size) and not is_310p(): + query, key = torch.ops._C_ascend.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + is_neox_style, + ) + return query.view(query_shape), key.view(key_shape) + if offsets is not None: + raise NotImplementedError( + "Batched rotary embedding is currently not supported on NPU.") + else: + if self.cos is not None and \ + self.sin is not None: + # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. + # This method requires head_size and rotary_dim equal 128 and neox_style is True + query = query.contiguous().view(1, query.shape[0], -1, + self.head_size) + key = key.contiguous().view(1, key.shape[0], -1, self.head_size) + torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin) + elif self.rotary_dim < self.head_size: + num_tokens = query.shape[0] + query = query.view(num_tokens, -1, self.head_size) + key = key.view(num_tokens, -1, self.head_size) + q_rot = query[..., :self.rotary_dim] + q_pass = query[..., self.rotary_dim:] + k_rot = key[..., :self.rotary_dim] + k_pass = key[..., self.rotary_dim:] + q_rot = q_rot.contiguous().view(num_tokens, -1) + k_rot = k_rot.contiguous().view(num_tokens, -1) + torch_npu._npu_rotary_embedding( + positions, + q_rot, + k_rot, + self.head_size, + self.cos_sin_cache, + is_neox_style, + ) + q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) + k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) + q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) + k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) + return q, k + else: + # TODO: Remove the contiguous in the future. + query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) + torch_npu._npu_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + is_neox_style, + ) + return query.view(query_shape), key.view(key_shape) class AscendRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding using Ascend NPU fused kernel. - Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application. 
- """ + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + self.cos = None + self.sin = None + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) def forward_oot( self, @@ -24,54 +131,301 @@ class AscendRotaryEmbedding(RotaryEmbedding): query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - import torch_npu # noqa: F401 + is_neox_style_override: Optional[bool] = None, + ): + is_neox_style = self.is_neox_style + if is_neox_style_override is not None: + is_neox_style = is_neox_style_override + forward_context = get_forward_context() + is_first_layer = forward_context.is_first_layer + # Generate cos and sin outside layers to avoid repeated calculation. + if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ + -1] == 128: + if is_first_layer: + cos_sin = self.cos_sin_cache.index_select(0, positions) + last_dim = cos_sin.size()[-1] + cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat( + 1, 1, 2).chunk(2, dim=-2) + # BSNH + self.cos = cos.view(1, -1, 1, last_dim).contiguous() + self.sin = sin.view(1, -1, 1, last_dim).contiguous() + forward_context.is_first_layer = False + return _rope_forward_oot(self, positions, query, key, is_neox_style, + offsets) - query_shape, key_shape = query.shape, key.shape - if self.cos_sin_cache.device != query.device: - self.cos_sin_cache = self.cos_sin_cache.to(query.device) - if self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) +class AscendYaRNRotaryEmbedding(YaRNScalingRotaryEmbedding): - if offsets is not None: - raise NotImplementedError( - "Batched rotary embedding is currently not supported on NPU." 
- ) + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.cos = None + self.sin = None + extra_kwargs = { + "extrapolation_factor": extrapolation_factor, + "attn_factor": attn_factor, + "beta_fast": beta_fast, + "beta_slow": beta_slow + } + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) - if self.rotary_dim < self.head_size: - # Partial rotary embedding: only rotate first rotary_dim dims - num_tokens = query.shape[0] - query = query.view(num_tokens, -1, self.head_size) - key = key.view(num_tokens, -1, self.head_size) + def forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + ): + return AscendRotaryEmbedding.forward_oot(self, positions, query, key, + offsets, + is_neox_style_override) - q_rot = query[..., :self.rotary_dim] - q_pass = query[..., self.rotary_dim:] - k_rot = key[..., :self.rotary_dim] - k_pass = key[..., self.rotary_dim:] - q_rot = q_rot.contiguous().view(num_tokens, -1) - k_rot = k_rot.contiguous().view(num_tokens, -1) +class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): - torch_npu._npu_rotary_embedding( - positions, q_rot, k_rot, - self.head_size, self.cos_sin_cache, self.is_neox_style, - ) + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + # Note: we adopt the native huggingface deepseek rope initialization code from + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for + # its more ascend compute friendly + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + self._yarn_get_mscale(self.scaling_factor, float(mscale)) / + self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super(DeepseekScalingRotaryEmbedding, + self).__init__(head_size, rotary_dim, max_position_embeddings, + base, is_neox_style, dtype) - q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) - k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) - q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) - k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) - return q, k - else: - # Full rotary embedding - # TODO: Remove the contiguous in the future. 
- query = query.contiguous().view(query.shape[0], -1) - key = key.contiguous().view(key.shape[0], -1) + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) + self._set_cos_sin_cache(self.max_seq_len, + device=NPUPlatform.device_type, + dtype=dtype) - torch_npu._npu_rotary_embedding( - positions, query, key, - self.head_size, self.cos_sin_cache, self.is_neox_style, - ) + def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 - return query.view(query_shape), key.view(key_shape) + def _rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _yarn_linear_ramp_mask(self, min_value, max_value, dim): + # Note: The if conditional branch is not used here + # to solve MTP compilation error. + max_value += (min_value == max_value).float() * 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - + min_value) / (max_value - min_value) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Inverse dim formula to find dim based on number of rotations + def _yarn_find_correction_dim(self, + num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. + return (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * + torch.log(torch.tensor(base))) + + # Find dim range bounds based on rotations + def _yarn_find_correction_range(self, + low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. + low = torch.floor( + self._yarn_find_correction_dim(low_rot, dim, base, + max_position_embeddings)) + high = torch.ceil( + self._yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + # Note: use torch instead of max/min to solve MTP compilation error. + return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) + + # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb + def _apply_rotary_pos_emb(self, + q, + k, + cos, + sin, + position_ids, + unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+ Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids] + sin = sin[position_ids] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + + if len(q.shape) == 3: + q = q[:, :, None, :] + if len(k.shape) == 2: + k = k[:, None, None, :] + elif len(k.shape) == 3: + k = k[:, :, None, :] + + b, h_q, s, d = q.shape + q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d) + + b, h_k, s, d = k.shape + k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d) + + q_embed = (q * cos) + (self._rotate_half(q) * sin) + k_embed = (k * cos) + (self._rotate_half(k) * sin) + + q_embed = q_embed.view(b, h_q, d) + k_embed = k_embed.view(b, h_k, d) + + return q_embed, k_embed + + def _set_cos_sin_cache(self, max_seq_len, device, dtype): + dim = self.rotary_dim + + freq_extra = 1.0 / (self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + + low, high = self._yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask( + low, high, dim // 2).to(device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - + inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale + sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale + cos_cached = cos_cached.to(dtype) + sin_cached = sin_cached.to(dtype) + cache = torch.cat( + [freqs.cos() * self.mscale, + freqs.sin() * self.mscale], dim=-1).to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.register_buffer("cos_cached", cos_cached, persistent=False) + self.register_buffer("sin_cached", sin_cached, persistent=False) + + def forward(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None): + if len(key.shape) == 2: + key = key[:, None, :] + # Note: we implement the non neox_style method with shuffle the last dim and neox style + # calculation method which is also more compute friendly to the ascend machine + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py + is_neox_style = True + if self.is_neox_style is False: + b, h_q, d = query.shape + query = query.view(b, h_q, d // 2, + 2).transpose(3, 2).reshape(b, h_q, d) + b, h_k, d = key.shape + key = key.view(b, h_k, d // 2, 2).transpose(3, + 2).reshape(b, h_k, d) + q_pe, k_pe = _rope_forward_oot(self, positions, query, key, + is_neox_style, offsets) + return q_pe, k_pe + + +class AscendMRotaryEmbedding(MRotaryEmbedding): + + def forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ): + if self.mrope_section != [16, 24, 24]: + return super().forward_oot(positions, query, key) + + import torch_npu + mrope_section = [0, 0, 0 + ] if positions.ndim == 1 else self.mrope_section + + if self.cos_sin_cache.device != query.device: # type: ignore + self.cos_sin_cache = self.cos_sin_cache.to( # type: ignore + query.device) # type: ignore + + if self.cos_sin_cache.dtype != query.dtype: # type: ignore + self.cos_sin_cache = self.cos_sin_cache.to( # 
type: ignore + query.dtype) # type: ignore + + query, key = torch_npu.npu_mrope(positions.contiguous(), + query.contiguous(), + key.contiguous(), + self.cos_sin_cache.contiguous(), + self.head_size, + mrope_section=mrope_section, + rotary_mode='half') + + return query, key \ No newline at end of file diff --git a/vllm_npu/ops/sigmoid_gating.py b/vllm_npu/ops/sigmoid_gating.py new file mode 100644 index 0000000..c99799c --- /dev/null +++ b/vllm_npu/ops/sigmoid_gating.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors + +import os +from typing import Optional + +import torch +from vllm.triton_utils import tl, tldevice, triton + +if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + + @triton.jit + def div_normal(x, y): + return x / y + + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +@triton.heuristics({ + 'USE_INITIAL_STATE': + lambda args: args['h0'] is not None, + 'IS_VARLEN': + lambda args: args['cu_seqlens'] is not None, + "IS_CONTINUOUS_BATCHING": + lambda args: args['ssm_state_indices'] is not None, + "IS_SPEC_DECODING": + lambda args: args['num_accepted_tokens'] is not None, +}) +@triton.jit(do_not_specialize=['N', 'T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.constexpr, # num of sequences + T: tl.constexpr, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl. 
+ constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_init_state_token + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t + p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t + p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t + else: + p_beta = beta + bos * HV + i_hv + HV * i_t + p_g = g + bos * HV + i_hv + HV * i_t + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t + + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + # b_h *= tl.exp(b_g) + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not 
supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # print("N: ", N) + # print("T: ", T) + # print("B: ", B) + # print("H: ", H) + # print("HV: ", HV) + # print("K: ", K) + # print("V: ", V) + # print("BK: ", BK) + # print("BV: ", BV) + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. 
+ For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if scale is None: + scale = k.shape[-1]**-0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state \ No newline at end of file diff --git a/vllm_npu/ops/vocab_parallel_embedding.py b/vllm_npu/ops/vocab_parallel_embedding.py new file mode 100644 index 0000000..b7f85b3 --- /dev/null +++ b/vllm_npu/ops/vocab_parallel_embedding.py @@ -0,0 +1,255 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn.parameter import Parameter +from vllm.distributed import divide +from vllm.distributed.parallel_state import get_tp_group +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod, + VocabParallelEmbedding, pad_vocab_size) +from vllm.model_executor.utils import set_weight_attrs + +from vllm_npu.distributed.parallel_state import get_lmhead_tp_group +from vllm_npu.utils import lmhead_tp_enable + + +class AscendVocabParallelEmbedding(VocabParallelEmbedding): + """ + Register VocabParallelEmbedding as a custom op for Ascend. + AscendVocabParallelEmbedding support different communication parallel groups + Added the feature of lmheadTP in pure dp scenario + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + nn.Module.__init__(self) + + if lmhead_tp_enable() and prefix.find("head") != -1: + self.comm_group = get_lmhead_tp_group() + else: + self.comm_group = get_tp_group() + + self.tp_size = self.comm_group.world_size + self.tp_rank = self.comm_group.rank_in_group + + self.num_embeddings = num_embeddings + self.padding_size = padding_size + self.org_vocab_size = org_num_embeddings or num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + self.tp_rank, self.tp_size) + self.embedding_dim = embedding_dim + quant_method = None + if quant_config is not None: + quant_method = quant_config.get_quant_method(self, prefix=prefix) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. + is_embedding_layer = type(self) is VocabParallelEmbedding + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method)) + if is_embedding_layer and not quant_method_implements_embedding: + raise NotImplementedError( + f"The class {type(quant_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod.") + + self.quant_method: QuantizeMethodBase = quant_method + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + # Divide the weight matrix along the vocaburaly dimension. 
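+        # For example, with a 32000-entry vocabulary and padding_size=64 the
+        # padded vocab stays at 32000 (already a multiple of 64), so with
+        # tp_size=8 each rank owns 32000 // 8 = 4000 rows of the weight matrix.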
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) + + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) + + def _get_masked_input_and_mask( + self, input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + # torch.compile will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & ( + input_ < org_vocab_end_index) + # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index. + if added_vocab_start_index == added_vocab_end_index: + valid_offset = (org_vocab_start_index * org_vocab_mask) + vocab_mask = org_vocab_mask + else: + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - + org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) + vocab_mask = org_vocab_mask | added_vocab_mask + # Adapt end. + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = self._get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, + masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. 
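+        # When tp_size > 1, rows owned by other ranks were zeroed out above,
+        # so summing the partial results across the model-parallel ranks
+        # reconstructs the full embedding for every input token.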
+ output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) + return output + + +class AscendParallelLMHead(ParallelLMHead): + """ + Register ParallelLMHead as a custom op for Ascend.""" + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + AscendVocabParallelEmbedding.__init__(self, num_embeddings, + embedding_dim, params_dtype, + org_num_embeddings, padding_size, + quant_config, prefix) + + self.quant_config = quant_config + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + +class AscendLogitsProcessor(LogitsProcessor): + """ + Register LogitsProcessor as a custom op for Ascend. + Added the feature of lmheadTP in pure dp scenario + """ + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + if lmhead_tp_enable(): + return self._get_logits_lmheadtp(hidden_states, lm_head, + embedding_bias) + else: + return self._get_logits_normal(hidden_states, lm_head, + embedding_bias) + + def _get_logits_lmheadtp( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + # Gather hidden states from all devices in tensor parallel group + gathered_hidden_states = get_lmhead_tp_group().all_gather( + hidden_states, dim=0) + local_logits = lm_head.quant_method.apply(lm_head, + gathered_hidden_states, + bias=embedding_bias) + # Gather logits for tensor parallel + logits = get_lmhead_tp_group().all_to_all(local_logits) + # Remove paddings in vocab (if any) + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits + + def _get_logits_normal( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + local_logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + # Gather logits for tensor parallel + logits = self._gather_logits(local_logits) + + # Remove paddings in vocab (if any) + if logits is not None: + logits = logits[..., :self.org_vocab_size] + + return logits diff --git a/vllm_npu/ops/weight_prefetch.py b/vllm_npu/ops/weight_prefetch.py new file mode 100644 index 0000000..761591d --- /dev/null +++ b/vllm_npu/ops/weight_prefetch.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass, field + +import torch +import torch_npu +from vllm.forward_context import get_forward_context + +from vllm_npu.ascend_config import WeightPrefetchConfig +from vllm_npu.ops.linear import (AscendQKVParallelLinear, + AscendRowParallelLinear) + +SUPPORTED_MODULES = ["attn", "mlp", "moe"] +MOE_PREFETCH_TOKEN_THRESHOLD = 96 + + +@dataclass +class ModuleWeightPrefetchConfig: + module_name: str + enable: bool = False + is_active_this_forward: bool = False + prefetch_ratio: dict = field(default_factory=dict) + linear_prefix_map: dict = field(default_factory=dict) + + def __post_init__(self) -> None: + self.prefetch_ratio = { + prefix: ratio + for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1 + } + + assert 
self.module_name in SUPPORTED_MODULES, ( + f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}" + ) + + if self.module_name in SUPPORTED_MODULES: + self.enable = self.enable and any(self.prefetch_ratio.values()) > 0 + + +class WeightPrefetchMethod: + """ + Unified weight prefetch method. + """ + + def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None: + self.attn = ModuleWeightPrefetchConfig( + module_name="attn", + enable=weight_prefetch_config.enabled, + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( + "attn", {}), + linear_prefix_map={ + AscendQKVParallelLinear.__name__: "qkv", + AscendRowParallelLinear.__name__: "o", + }) + self.moe = ModuleWeightPrefetchConfig( + module_name="moe", + enable=weight_prefetch_config.enabled, + prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( + "moe", {})) + + def maybe_prefetch_attn_weight_preprocess( + self, layer_cls_name: str, weight: torch.Tensor, + start_flag: torch.Tensor) -> None: + if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: + return + + prefix = self.attn.linear_prefix_map.get(layer_cls_name, "") + weight_size = weight.data.element_size() * weight.data.numel( + ) * self.attn.prefetch_ratio.get(prefix, 0) + + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=start_flag, + max_weight_size=int(weight_size)) + + def maybe_prefetch_attn_weight_postprocess( + self, layer_cls_name: str, stop_flag: torch.Tensor) -> None: + if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: + return + + torch.ops.vllm.prefetch_postprocess(stop_flag) + + def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix): + self.moe.is_active_this_forward = hidden_states.shape[ + 0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False + if not self.moe.is_active_this_forward: + return + forward_context = get_forward_context() + # layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm. + weight = forward_context.model_instance.model.layers[ + forward_context.layer_idx - 1].mlp.experts.w13_weight + weight_size = weight.data.element_size() * weight.data.numel( + ) * self.moe.prefetch_ratio.get(prefix, 0) + torch.ops.vllm.prefetch_preprocess(weight=weight, + start_flag=None, + max_weight_size=int(weight_size)) + + def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor): + if not self.moe.is_active_this_forward: + return + + torch.ops.vllm.prefetch_postprocess(stop_flag) + + +def maybe_npu_prefetch(inputs: torch.Tensor, + dependency: torch.Tensor, + max_size: int = 0, + offset: int = 0, + *, + enabled: bool = True) -> None: + if not enabled: + return + input_size = inputs.element_size() * inputs.numel() + if max_size <= 0 or max_size > input_size: + max_size = input_size + torch_npu.npu_prefetch(inputs, dependency, max_size, offset) diff --git a/vllm_npu/patch/__init__.py b/vllm_npu/patch/__init__.py new file mode 100644 index 0000000..47ce6dc --- /dev/null +++ b/vllm_npu/patch/__init__.py @@ -0,0 +1,174 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ----------------------------------------------------------------------------------
+# This module manages the patches for vllm. There are two folders in this module:
+# - platform: contains the patches applied before the worker starts. They are applied by
+#   `vllm_npu.utils.adapt_patch(is_global_patch=True)` in the
+#   `vllm_npu.platform.NPUPlatform.pre_register_and_update()` function.
+# - worker: contains the patches applied when the worker starts. They are applied by
+#   `vllm_npu.utils.adapt_patch(is_global_patch=False)` in
+#   each worker's `__init__` function.
+#
+# Whenever a new patch is added to vllm-ascend, please add its description to this file as well.
+# ----------------------------------------------------------------------------------
+
+# What's patched and how it works:
+# --------------------------------
+# * Platform Patch:
+# =================
+# ** File: platform/patch_distributed.py**
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.config.ParallelConfig.get_next_dp_init_port`
+#       Why:
+#           vLLM doesn't support reading the port from the environment.
+#       How:
+#           Add logic to read the port from the environment.
+#       Related PR (if no, explain why):
+#           A PR to vLLM is needed to support reading the port from the environment.
+#       Future Plan:
+#           Remove this patch once vLLM merges such a PR.
+#    2. `torch.distributed.all_reduce`, `torch.distributed.broadcast`
+#       Why:
+#           Tensor alignment is required on 310P.
+#       How:
+#           Rewrite all_reduce and broadcast in torch.distributed.
+#       Related PR (if no, explain why):
+#           No, not ready yet.
+#       Future Plan:
+#           Find a better way to support tensor alignment on 310P without this patch.
+#
+# ** File: worker/patch_multimodal_merge.py**
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.model_executor.models.utils._merge_multimodal_embeddings`
+#       Why:
+#           The `_merge_multimodal_embeddings` function in vLLM is incompatible with Ascend.
+#       How:
+#           Replace it with a CPU operation that can be executed asynchronously.
+#       Related PR (if no, explain why):
+#           This is an Ascend-only issue; it can't be fixed in vLLM.
+#       Future Plan:
+#           Identify this pattern in torch-npu and remove this patch.
+#
+# * Worker Patch:
+# ===============
+# ** File: worker/patch_minicpm.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.model_executor.models.minicpm.MiniCPMAttention.forward`
+#       Why:
+#           The forward function of MiniCPMAttention in vLLM converts tensors
+#           (original dtype --> float32) to preserve precision on CUDA.
+#           However, float32 is not supported by the CANN RoPE op, so we keep this patch.
+#       How:
+#           Remove the dtype conversion in forward.
+#       Related PR (if no, explain why):
+#           No; this is NPU-only behavior caused by the RoPE op.
+#       Future Plan:
+#           Keep this patch in vllm-ascend.
+#
+# ** File: worker/patch_distributed.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.distributed.parallel_state.GroupCoordinator`
+#       (1) __init__()
+#       Why:
+#           The original GroupCoordinator initialization lacks pg_options for creating new
+#           process groups with customized options.
+#       How:
+#           Inject HCCL options during process group initialization.
+#       Related PR (if no, explain why):
+#           A PR to vLLM is needed to support a dictionary as input when initializing the
+#           distributed environment (e.g., Dict[str, torch.distributed.ProcessGroupHCCL.Options]):
+#           https://github.com/vllm-project/vllm/pull/25417
+#       Future Plan:
+#           Remove this patch when vLLM merges this PR.
+#       (2) all_to_all()
+#       Why:
+#           vLLM doesn't provide all_to_all on GroupCoordinator.
+#       How:
+#           Add an all_to_all implementation to GroupCoordinator.
+#       Related PR (if no, explain why):
+#           A PR to vLLM is needed to support all_to_all on GroupCoordinator.
+#       Future Plan:
+#           Remove this patch once vLLM merges it.
+#
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.v1.sample.sampler.Sampler.gather_logprobs`
+#       Why:
+#           We need to patch gather_logprobs to make sure batched_count_greater_than is called
+#           with backend=current_platform.simple_compile_backend.
+#       How:
+#           Patch gather_logprobs to call the new batched_count_greater_than.
+#       Related PR (if no, explain why):
+#           - https://github.com/vllm-project/vllm/pull/21591
+#       Future Plan:
+#           Revert this once vLLM merges #21591 and releases a new version.
+# ** File: worker/patch_logits.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm._custom_ops.apply_repetition_penalties`
+#       Why:
+#           apply_repetition_penalties in vLLM uses tensor.is_cuda to check whether the tensor
+#           is on CUDA, but the value is always True on Ascend, so we need to patch it.
+#       How:
+#           Remove the related CUDA check in apply_repetition_penalties.
+#       Related PR (if no, explain why):
+#           This is an Ascend-only issue; it can't be fixed in vLLM.
+#       Future Plan:
+#           Fix this bug in torch-npu, bump the torch-npu version, and remove this patch.
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.model_executor.models.roberta.RobertaEmbedding.forward`
+#       Why:
+#           The shift operations in `_encode_token_type_ids` and `_decode_token_type_ids`
+#           cannot run in Ascend aclgraph mode.
+#       How:
+#           Replace the shift operation with multiplication and division.
+#       Related PR (if no, explain why):
+#           No; this requires CANN to add an aclnn shift operator.
+#       Future Plan:
+#           Revert this once CANN supports an aclnn shift operator.
+#    2. `vllm.model_executor.models.roberta.RobertaForSequenceClassification.forward`
+#       Why:
+#           The shift operations in `_encode_token_type_ids` and `_decode_token_type_ids`
+#           cannot run in Ascend aclgraph mode.
+#       How:
+#           Replace the shift operation with multiplication and division.
+#       Related PR (if no, explain why):
+#           No; this requires CANN to add an aclnn shift operator.
+#       Future Plan:
+#           Revert this once CANN supports an aclnn shift operator.
+#
+# ** File: worker/patch_deepseek_mtp.py**
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.__init__`
+#       Why:
+#           The `__init__` of DeepSeekMultiTokenPredictorLayer doesn't pass prefix to SharedHead.
+#       How:
+#           Replace it with a new __init__ that uses a new SharedHead which passes prefix to
+#           ParallelLMHead.
+#       Related PR (if no, explain why):
+#           https://github.com/vllm-project/vllm/pull/25805
+#       Future Plan:
+#           Remove this patch once the adapted vLLM version contains the above PR.
+#
+# ** File: worker/patch_attention_layer.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.attention.layer.Attention.forward`
+#       Why:
+#           There is a zerolike operator before the attention operation in each decoding stage.
+# How +# Replace this zerolike operator with torch.empty +# Related PR (if no, explain why): +# - https://github.com/vllm-project/vllm/pull/26680 +# Future Plan: +# Remove this to match the optimization supported in the VLLM version. +# diff --git a/vllm_npu/patch/platform/__init__.py b/vllm_npu/patch/platform/__init__.py new file mode 100644 index 0000000..a096901 --- /dev/null +++ b/vllm_npu/patch/platform/__init__.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import vllm_npu.patch.platform.patch_config # noqa +import vllm_npu.patch.platform.patch_distributed # noqa +import vllm_npu.patch.platform.patch_mamba_config # noqa +import vllm_npu.patch.platform.patch_sched_yield # noqa + +if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv( + "EXPERT_MAP_RECORD", "false") == "true": + import vllm_npu.patch.platform.patch_multiproc_executor # noqa + +if os.getenv("SHM_BARRIER", "true") == "true": + import vllm_npu.patch.platform.patch_core # noqa + import vllm_npu.patch.platform.patch_message_queue # noqa diff --git a/vllm_npu/patch/platform/patch_config.py b/vllm_npu/patch/platform/patch_config.py new file mode 100644 index 0000000..d615038 --- /dev/null +++ b/vllm_npu/patch/platform/patch_config.py @@ -0,0 +1,234 @@ +import ast + +import vllm.envs as envs +from vllm.config.speculative import SpeculativeConfig +from vllm.logger import logger + + +def __post_init__(self): + + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + if (self.target_model_config + and self.target_model_config.hf_text_config.model_type + in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe", + "qwen3_next")): + # use the draft model from the same model: + self.model = self.target_model_config.model + # Align the quantization of draft model for cases such as + # --quantization fp8 with a bf16 checkpoint. 
+ if not self.quantization: + self.quantization = self.target_model_config.quantization + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError("num_speculative_tokens was provided but without " + "speculative model.") + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if (self.prompt_lookup_min is None and self.prompt_lookup_max is None): + # TODO(woosuk): Tune these values. They are arbitrarily chosen. + self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + # TODO: Move this import to the top once `ModelConfig` + # lives in `vllm.config.model`. + from vllm.config import ModelConfig + self.draft_model_config = ModelConfig( + model=self.model, + runner="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config.trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + allowed_media_domains=self.target_model_config. + allowed_media_domains, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config.tokenizer_revision, + spec_target_max_model_len=self.target_model_config. 
+ max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if self.method in ('eagle', 'eagle3'): + pass + # examples: + # yuhuili/EAGLE-LLaMA3-Instruct-8B + # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B + # AngelSlim/Qwen3-8B_eagle3 + elif "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif "eagle3" in self.draft_model_config.model.lower(): + self.method = "eagle3" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type + in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): + self.method = "deepseek_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == + "qwen3_next_mtp"): + self.method = "qwen3_next_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Qwen3Next MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type + in ("longcat_flash_mtp")): + self.method = "longcat_flash_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "LongCat MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + else: + self.method = "draft_model" + raise NotImplementedError( + "Speculative decoding with draft model is not " + "supported yet. Please consider using other " + "speculative decoding methods such as ngram, medusa, " + "eagle, or deepseek_mtp.") + + # Replace hf_config for EAGLE draft_model + if self.method in ("eagle", "eagle3"): + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") + + from vllm.transformers_utils.configs import SpeculatorsConfig + from vllm.transformers_utils.configs.eagle import EAGLEConfig + + if isinstance(self.draft_model_config.hf_config, + (EAGLEConfig, SpeculatorsConfig)): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config, + method=self.method, + model_type="eagle") + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, "n_predict", + None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. 
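+                # For example, with n_predict=2, num_speculative_tokens=2 or 4
+                # is accepted, while 3 triggers the error below.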
+ raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str([ + (i + 1) * (0, ) for i in range(self.num_speculative_tokens) + ]) + else: + # Sort the token tree breadth-first. + tree_choices = ast.literal_eval(self.speculative_token_tree) + self.speculative_token_tree = str( + sorted(tree_choices, key=lambda t: (len(t), t))) + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) + + +SpeculativeConfig.__post_init__ = __post_init__ diff --git a/vllm_npu/patch/platform/patch_core.py b/vllm_npu/patch/platform/patch_core.py new file mode 100644 index 0000000..56a519f --- /dev/null +++ b/vllm_npu/patch/platform/patch_core.py @@ -0,0 +1,68 @@ +import signal +from typing import Optional + +from vllm.config import ParallelConfig +from vllm.logger import logger +from vllm.transformers_utils.config import \ + maybe_register_config_serialize_by_value +from vllm.utils import decorate_logs, set_process_title +from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc + + +def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): + """Launch EngineCore busy loop in background process.""" + + from vllm.distributed.device_communicators.shm_broadcast import \ + MessageQueue # noqa + + # Signal handler used for graceful termination. + # SystemExit exception is only raised once to allow this and worker + # processes to terminate without error + shutdown_requested = False + + # Ensure we can serialize transformer config after spawning + maybe_register_config_serialize_by_value() + + def signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the engine_core + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + engine_core: Optional[EngineCoreProc] = None + try: + parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config + if parallel_config.data_parallel_size > 1 or dp_rank > 0: + set_process_title("EngineCore", f"DP{dp_rank}") + decorate_logs() + # Set data parallel rank for this engine process. 
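+            # dp_rank is the global data-parallel rank across all nodes, while
+            # local_dp_rank only counts the engine processes on this node.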
+ parallel_config.data_parallel_rank = dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank + engine_core = DPEngineCoreProc(*args, **kwargs) + else: + set_process_title("EngineCore") + decorate_logs() + engine_core = EngineCoreProc(*args, **kwargs) + + engine_core.run_busy_loop() + + except SystemExit: + logger.debug("EngineCore exiting.") + raise + except Exception as e: + if engine_core is None: + logger.exception("EngineCore failed to start.") + else: + logger.exception("EngineCore encountered a fatal error.") + engine_core._send_engine_dead() + raise e + finally: + if engine_core is not None: + engine_core.shutdown() + + +EngineCoreProc.run_engine_core = run_engine_core diff --git a/vllm_npu/patch/platform/patch_distributed.py b/vllm_npu/patch/platform/patch_distributed.py new file mode 100644 index 0000000..1408a80 --- /dev/null +++ b/vllm_npu/patch/platform/patch_distributed.py @@ -0,0 +1,115 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from vllm/model_executor/models/qwen2_vl.py +# This file is a part of the vllm-ascend project. + +import torch +import vllm.envs as envs_vllm +from vllm.config import ParallelConfig + +from vllm_npu.utils import is_310p + + +def parallel_config_get_dp_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. To avoid port conflicts, we + increment the port number each time we need to initialize a + new process group related to data parallelism. + """ + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + + # NOTE: Get port from envs directly when using torchrun + port = envs_vllm.VLLM_DP_MASTER_PORT if envs_vllm.VLLM_DP_MASTER_PORT else answer + return port + + +ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port + + +class NullHandle: + + def __init__(self): + pass + + def wait(self): + pass + + +def communication_adaptation_310p(): + + def broadcast310p_wrapper(fn): + + def broadcast310p(tensor, src, group=None, async_op=False): + if tensor.device == torch.device('cpu'): + return fn(tensor, src, group, async_op) + rank = torch.distributed.get_rank(group) + world_size = torch.distributed.get_world_size(group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + tensor_list[rank] = tensor + torch.distributed.all_gather(tensor_list, tensor, group=group) + tensor[...] 
= tensor_list[src] + if async_op: + return NullHandle() + else: + return None + + return broadcast310p + + torch.distributed.broadcast = broadcast310p_wrapper( + torch.distributed.broadcast) + torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper( + torch.distributed.distributed_c10d.broadcast) + + def all_reduce_wrapper_310p(fn): + + def all_reduce( + tensor, + op=torch.distributed.ReduceOp.SUM, + group=None, + async_op=False, + ): + if tensor.dtype != torch.int64: + return fn(tensor, op, group, async_op) + rank = torch.distributed.get_rank(group) + world_size = torch.distributed.get_world_size(group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + tensor_list[rank] = tensor + torch.distributed.all_gather(tensor_list, tensor, group=group) + if op == torch.distributed.ReduceOp.SUM: + return torch.stack(tensor_list).sum(0) + elif op == torch.distributed.ReduceOp.MAX: + return torch.tensor( + torch.stack(tensor_list).cpu().numpy().max(0), + device=tensor.device, + ) + else: + raise RuntimeError(f"not implement op {op}") + + return all_reduce + + torch.distributed.all_reduce = all_reduce_wrapper_310p( + torch.distributed.all_reduce) + torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p( + torch.distributed.distributed_c10d.all_reduce) + + +if is_310p(): + communication_adaptation_310p() diff --git a/vllm_npu/patch/platform/patch_mamba_config.py b/vllm_npu/patch/platform/patch_mamba_config.py new file mode 100644 index 0000000..1420fac --- /dev/null +++ b/vllm_npu/patch/platform/patch_mamba_config.py @@ -0,0 +1,96 @@ +# mypy: ignore-errors +import vllm.model_executor.models.config +from vllm.logger import init_logger +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.config import MambaModelConfig +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec + + +@classmethod +def verify_and_update_config(cls, vllm_config) -> None: + """ + Ensure that page size of attention layers is greater than or + equal to the mamba layers. If not, automatically set the attention + block size to ensure that it is. If the attention page size is + strictly greater than the mamba page size, we pad the mamba page size + to make them equal. + + Args: + vllm_config: vLLM Config + """ + logger = init_logger(__name__) + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # get attention page size (for 1 token) + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype).page_size_bytes + + model_cls, _ = ModelRegistry.resolve_model_cls( + model_config.architecture, + model_config=model_config, + ) + + # get mamba page size + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), + block_size=model_config.max_model_len, + ).page_size_bytes + + block_alignment_bytes = 128 + + # some attention backends (e.g. 
FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = block_alignment_bytes * cdiv( + mamba_page_size, block_alignment_bytes * attn_page_size_1_token) + + # override attention block size if either (a) the + # user has not set it or (b) the user has set it + # too small. + if (cache_config.block_size is None + or cache_config.block_size < attn_block_size): + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + # compute new attention page size + attn_page_size = \ + cache_config.block_size * attn_page_size_1_token + + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + # don't need to pad mamba page size + return + + # pad mamba page size to exactly match attention + if (cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size): + cache_config.mamba_page_size_padded = (attn_page_size) + mamba_padding_pct = 100 * (attn_page_size - + mamba_page_size) / mamba_page_size + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) + + +vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config diff --git a/vllm_npu/patch/platform/patch_message_queue.py b/vllm_npu/patch/platform/patch_message_queue.py new file mode 100644 index 0000000..7bf183c --- /dev/null +++ b/vllm_npu/patch/platform/patch_message_queue.py @@ -0,0 +1,164 @@ +import time +from contextlib import contextmanager +from typing import Optional + +import vllm.envs as envs +from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue, + ShmRingBuffer, + SpinTimer) +from vllm.distributed.utils import sched_yield +from vllm.logger import logger +from vllm.utils import (get_ip, get_mp_context, get_open_port, + get_open_zmq_ipc_path, is_valid_ipv6_address) +from zmq import IPV6, XPUB, XPUB_VERBOSE, Context # type: ignore + +VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL + + +def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[list[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, +): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, + max_chunks) + + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. 
otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) + local_subscribe_addr = get_open_zmq_ipc_path() + logger.debug("Binding to %s", local_subscribe_addr) + self.local_socket.bind(local_subscribe_addr) + + self.current_idx = 0 + self.writer_lock = get_mp_context().Lock() + else: + self.buffer = None # type: ignore + local_subscribe_addr = None + self.local_socket = None + self.current_idx = -1 + + remote_addr_ipv6 = False + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + if not connect_ip: + connect_ip = get_ip() + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) + remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + remote_addr_ipv6 = True + connect_ip = f"[{connect_ip}]" + socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) + remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" + else: + remote_subscribe_addr = None + self.remote_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + self._read_spin_timer = SpinTimer() + + self.handle = Handle( + local_reader_ranks=local_reader_ranks, + buffer_handle=self.buffer.handle() + if self.buffer is not None else None, + local_subscribe_addr=local_subscribe_addr, + remote_subscribe_addr=remote_subscribe_addr, + remote_addr_ipv6=remote_addr_ipv6, + ) + + logger.info("vLLM message queue communication handle: %s", self.handle) + + +@contextmanager +def acquire_write(self, timeout: Optional[float] = None): + assert self._is_writer, "Only writers can acquire write" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_count = sum(metadata_buffer[1:]) + written_flag = metadata_buffer[0] + if written_flag and read_count != self.buffer.n_reader: + # this block is written and not read by all readers + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # Release the processor to other threads + sched_yield() + + # if we time out, raise an exception + elapsed = time.monotonic() - start_time + if timeout is not None and elapsed > timeout: + raise TimeoutError + + # if we wait for a long time, log a message + if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: + logger.info( + "No available shared memory broadcast block found" + " in %s seconds. This typically happens when some" + " processes are hanging or doing some" + " time-consuming work (e.g. 
compilation)", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + with self.writer_lock: + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks + break + + +MessageQueue.__init__ = __init__ +MessageQueue.acquire_write = acquire_write diff --git a/vllm_npu/patch/platform/patch_multiproc_executor.py b/vllm_npu/patch/platform/patch_multiproc_executor.py new file mode 100644 index 0000000..82b16fc --- /dev/null +++ b/vllm_npu/patch/platform/patch_multiproc_executor.py @@ -0,0 +1,151 @@ +import threading +import weakref +from concurrent.futures import ThreadPoolExecutor +from multiprocessing.synchronize import Lock as LockType +from typing import Optional + +import vllm.v1.executor.multiproc_executor +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.utils import (get_distributed_init_method, get_loopback_ip, + get_mp_context, get_open_port) +from vllm.v1.executor.abstract import FailureCallback +from vllm.v1.executor.multiproc_executor import ( + MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc, + set_multiprocessing_worker_envs) + + +class AscendMultiprocExecutor(MultiprocExecutor): + supports_pp: bool = True + + def _init_executor(self) -> None: + # Call self.shutdown at exit to clean up + # and ensure workers will be terminated. + self._finalizer = weakref.finalize(self, self.shutdown) + self.is_failed = False + self.shutdown_event = threading.Event() + self.failure_callback: Optional[FailureCallback] = None + self.io_thread_pool: Optional[ThreadPoolExecutor] = None + + self.world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + pp_parallel_size = self.parallel_config.pipeline_parallel_size + assert self.world_size == tensor_parallel_size * pp_parallel_size, ( + f"world_size ({self.world_size}) must be equal to the " + f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" + f"_parallel_size ({pp_parallel_size}). ") + + # Set multiprocessing envs + set_multiprocessing_worker_envs() + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # get_loopback_ip() for communication. 
+ distributed_init_method = get_distributed_init_method( + get_loopback_ip(), get_open_port()) + + # Initialize worker and set up message queues for SchedulerOutputs + # and ModelRunnerOutputs + max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 + self.rpc_broadcast_mq = MessageQueue(self.world_size, + self.world_size, + max_chunk_bytes=max_chunk_bytes) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() + + # Create workers + context = get_mp_context() + shared_worker_lock = context.Lock() + unready_workers: list[UnreadyWorkerProcHandle] = [] + success = False + try: + for rank in range(self.world_size): + unready_workers.append( + AscendWorkerProc.make_worker_process( + vllm_config=self.vllm_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + shared_worker_lock=shared_worker_lock, + )) + + # Workers must be created before wait_for_ready to avoid + # deadlock, since worker.init_device() does a device sync. + self.workers = WorkerProc.wait_for_ready(unready_workers) + + # Ensure message queues are ready. Will deadlock if re-ordered + # Must be kept consistent with the WorkerProc. + self.rpc_broadcast_mq.wait_until_ready() + for w in self.workers: + w.worker_response_mq.wait_until_ready() + + self.start_worker_monitor() + success = True + finally: + if not success: + # Clean up the worker procs if there was a failure. + # Close death_writers first to signal workers to exit + for uw in unready_workers: + if uw.death_writer is not None: + uw.death_writer.close() + self._ensure_worker_termination( + [uw.proc for uw in unready_workers]) + + # For pipeline parallel, we use a thread pool for asynchronous + # execute_model. + if self.max_concurrent_batches > 1: + # Note: must use only 1 IO thread to keep dequeue sequence + # from the response queue + # _async_aggregate_workers_output also assumes a single IO thread + self.io_thread_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mp_exec_io") + + self.output_rank = self._get_output_rank() + self.has_connector = self.vllm_config.kv_transfer_config is not None + + +class AscendWorkerProc(WorkerProc): + + @staticmethod + def make_worker_process( + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle, # Receive SchedulerOutput + shared_worker_lock: LockType, + ) -> UnreadyWorkerProcHandle: + context = get_mp_context() + # (reader, writer) + reader, writer = context.Pipe(duplex=False) + + # Create death pipe to detect parent process exit + death_reader, death_writer = context.Pipe(duplex=False) + + process_kwargs = { + "vllm_config": vllm_config, + "local_rank": local_rank, + "rank": rank, + "distributed_init_method": distributed_init_method, + "input_shm_handle": input_shm_handle, + "ready_pipe": (reader, writer), + "death_pipe": death_reader, + "shared_worker_lock": shared_worker_lock, + } + # Run EngineCore busy loop in background process. 
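# Illustrative sketch, not part of the patch: why the death pipe created above
# detects parent exit. Once the parent's write end is closed (explicitly, or
# because the parent process died), the child's blocking recv() raises
# EOFError. Function and variable names below are made up.
import multiprocessing as mp


def _child(death_pipe):
    try:
        death_pipe.recv()               # blocks until the writer closes
    except EOFError:
        print("parent is gone, shutting down")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    death_reader, death_writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=_child, args=(death_reader,))
    proc.start()
    death_writer.close()                # simulates the parent going away
    proc.join()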
+ proc = context.Process( + target=WorkerProc.worker_main, + kwargs=process_kwargs, + name=f"VllmWorker-{rank}", + daemon=False, + ) + + proc.start() + writer.close() + # Keep death_writer open in parent - when parent exits, + # death_reader in child will get EOFError + return UnreadyWorkerProcHandle(proc, rank, reader, death_writer) + + +vllm.v1.executor.multiproc_executor.MultiprocExecutor = AscendMultiprocExecutor diff --git a/vllm_npu/patch/platform/patch_sched_yield.py b/vllm_npu/patch/platform/patch_sched_yield.py new file mode 100644 index 0000000..694b957 --- /dev/null +++ b/vllm_npu/patch/platform/patch_sched_yield.py @@ -0,0 +1,13 @@ +import sys + +import vllm.distributed.utils +from vllm.platforms import CpuArchEnum, Platform + +is_arm = (Platform.get_cpu_architecture() == CpuArchEnum.ARM) + +USE_SCHED_YIELD = ( + ((sys.version_info[:3] >= (3, 11, 1)) or + (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8)) + and not is_arm) + +vllm.distributed.utils.USE_SCHED_YIELD = USE_SCHED_YIELD diff --git a/vllm_npu/patch/worker/__init__.py b/vllm_npu/patch/worker/__init__.py new file mode 100644 index 0000000..225edbd --- /dev/null +++ b/vllm_npu/patch/worker/__init__.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + import vllm_npu.patch.worker.patch_triton + +# isort: off +import vllm_npu.patch.platform.patch_sched_yield # noqa +import vllm_npu.patch.worker.patch_distributed # noqa +import vllm_npu.patch.worker.patch_logits # noqa +import vllm_npu.patch.worker.patch_roberta # noqa +import vllm_npu.patch.worker.patch_weight_loader # noqa +import vllm_npu.patch.worker.patch_multimodal_merge # noqa +import vllm_npu.patch.worker.patch_minicpm # noqa +import vllm_npu.patch.worker.patch_deepseek_mtp # noqa +import vllm_npu.patch.worker.patch_attention_layer # noqa + +if os.getenv("SHM_BARRIER", "true") == "true": + import vllm_npu.patch.platform.patch_message_queue # noqa diff --git a/vllm_npu/patch/worker/patch_attention_layer.py b/vllm_npu/patch/worker/patch_attention_layer.py new file mode 100644 index 0000000..62638cf --- /dev/null +++ b/vllm_npu/patch/worker/patch_attention_layer.py @@ -0,0 +1,92 @@ +from typing import Optional + +import torch +import vllm +from vllm.forward_context import ForwardContext, get_forward_context + + +def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. + output_shape: Optional[torch.Size] = None, +) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. 
+ Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(query, key, value) + + output_dtype = query.dtype + if self.query_quant is not None: + # quantizing with a simple torch operation enables + # torch.compile to fuse this into previous ops + # which reduces overheads during decoding. + # Otherwise queries are quantized using custom ops + # which causes decoding overheads + assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} + query, _ = self.query_quant(query, self._q_scale) + + if self.use_output: + output_shape = (output_shape + if output_shape is not None else query.shape) + output = torch.empty(output_shape, + dtype=output_dtype, + device=query.device) + hidden_size = output_shape[-1] + # We skip reshaping query, key and value tensors for the MLA + # backend since these tensors have different semantics and are + # processed differently. + if not self.use_mla: + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + self_kv_cache, + attn_metadata, + output=output) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name) + return output.view(-1, hidden_size) + else: + if self.use_direct_call: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + return self.impl.forward(self, query, key, value, self_kv_cache, + attn_metadata) + else: + return torch.ops.vllm.unified_attention(query, key, value, + self.layer_name) + + +vllm.attention.layer.Attention.forward = forward \ No newline at end of file diff --git a/vllm_npu/patch/worker/patch_deepseek_mtp.py b/vllm_npu/patch/worker/patch_deepseek_mtp.py new file mode 100644 index 0000000..68ac359 --- /dev/null +++ b/vllm_npu/patch/worker/patch_deepseek_mtp.py @@ -0,0 +1,94 @@ +from typing import Optional + +import torch +import torch.nn as nn +import vllm +from transformers import PretrainedConfig +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.deepseek_mtp import ( + DeepSeekMTP, DeepSeekMultiTokenPredictorLayer) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer +from 
vllm.model_executor.models.utils import maybe_prefix + + +def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, +) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + # Patch this for aclgraph support, as the original operation introduced d2h sync, + # which breaks aclgraph + inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds) + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +# Patch this only for aclgraph support, as this is not support in vLLM 0.11.0 +@support_torch_compile +class AscendDeepSeekMTP(DeepSeekMTP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + quant_config: QuantizationConfig = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head"), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + + # We don't need topk_indices_buffer in Ascend + topk_indices_buffer = None + self.shared_head = SharedHead(config=config, + prefix=prefix, + quant_config=quant_config) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer) + + +DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init +vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward diff --git a/vllm_npu/patch/worker/patch_distributed.py b/vllm_npu/patch/worker/patch_distributed.py new file mode 100644 index 0000000..e6cfae6 --- /dev/null +++ b/vllm_npu/patch/worker/patch_distributed.py @@ -0,0 +1,115 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
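# Illustrative sketch, not part of the patch: the torch.where masking used in
# the MTP forward above is numerically identical to boolean-index assignment,
# but keeps the kernel shape-static so graph capture does not need a
# device-to-host sync. Tensor values below are made up.
import torch

positions = torch.tensor([0, 1, 2, 0])
inputs_embeds = torch.randn(4, 8)

graph_friendly = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)

reference = inputs_embeds.clone()
reference[positions == 0] = 0           # same result, but syncs on the mask
assert torch.equal(graph_friendly, reference)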
+# + +from typing import List, Optional, Union + +import torch +import vllm +from torch.distributed import Backend +from vllm.distributed.parallel_state import (GroupCoordinator, + _get_unique_name, _register_group) + +from vllm_npu.distributed.communicator import NPUCommunicator +from vllm_npu.utils import create_hccl_pg_options + + +class GroupCoordinatorPatch(GroupCoordinator): + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_device_communicator: bool, # whether to use device communicator + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + + self_device_group = None + self_cpu_group = None + hccl_pg_options = create_hccl_pg_options(group_name) + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, + backend=torch_distributed_backend, + pg_options=hccl_pg_options) + + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self_device_group = device_group + self_cpu_group = cpu_group + + assert self_cpu_group is not None + assert self_device_group is not None + + self.cpu_group = self_cpu_group + self.device_group = self_device_group + self.device = torch.npu.current_device() + + self.use_device_communicator = use_device_communicator + self.device_communicator = None + if use_device_communicator and self.world_size > 1: + self.device_communicator = NPUCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + + from vllm.distributed.device_communicators.shm_broadcast import \ + MessageQueue + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + self.use_custom_op_call = False + self.use_cpu_custom_send_recv = False + + def all_to_all(self, + input_: torch.Tensor, + scatter_dim: int = 0, + gather_dim: int = -1, + scatter_sizes: Optional[List[int]] = None, + gather_sizes: Optional[List[int]] = None) -> torch.Tensor: + if self.world_size == 1: + return input_ + assert -input_.dim() <= scatter_dim < input_.dim(), ( + f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}" + ) + assert -input_.dim() <= gather_dim < input_.dim(), ( + f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}" + ) + assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1" + return self.device_communicator.all_to_all(input_, scatter_dim, + gather_dim, scatter_sizes, + gather_sizes) + + +vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch diff --git a/vllm_npu/patch/worker/patch_logits.py b/vllm_npu/patch/worker/patch_logits.py new file mode 100644 index 0000000..84a92f9 --- /dev/null +++ b/vllm_npu/patch/worker/patch_logits.py @@ -0,0 +1,26 @@ +import torch +import vllm +from vllm._custom_ops import apply_repetition_penalties_torch + + +def apply_repetition_penalties(logits: torch.Tensor, 
prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor) -> None: + """Apply repetition penalties to logits in-place. + + Args: + logits: The logits tensor of shape [num_seqs, vocab_size]. + prompt_mask: A boolean tensor indicating which tokens appear in the prompt. + output_mask: A boolean tensor indicating which tokens appear in the output. + repetition_penalties: The repetition penalties of shape (num_seqs, ). + """ + apply_repetition_penalties_torch(logits, prompt_mask, output_mask, + repetition_penalties) + + +# NPU device type tensors have attributes is_cuda=True and is_npu=True, according to its implementation in +# https://github.com/Ascend/pytorch/blob/863b9071cbdf47023c12c246e3efa9c6e2285fc6/torch_npu/npu/_stream_check.py#L74 +# This causes that vLLM's apply_repetition_penalties function will run into the branch of "if logits.is_cuda" and +# call the custom op implemented in CUDA, which is not compatible with NPU. +# Reference: https://github.com/vllm-project/vllm/blob/f66673a39d9f364194c249f28098cad8a5584ccb/vllm/_custom_ops.py#L314 +vllm._custom_ops.apply_repetition_penalties = apply_repetition_penalties diff --git a/vllm_npu/patch/worker/patch_minicpm.py b/vllm_npu/patch/worker/patch_minicpm.py new file mode 100644 index 0000000..663a08a --- /dev/null +++ b/vllm_npu/patch/worker/patch_minicpm.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +from vllm.model_executor.models.minicpm import MiniCPMAttention + + +def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, +) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +# The type conversion in the forward function is deleted to support the rope operator. +MiniCPMAttention.forward = forward diff --git a/vllm_npu/patch/worker/patch_multimodal_merge.py b/vllm_npu/patch/worker/patch_multimodal_merge.py new file mode 100644 index 0000000..c8a1d5c --- /dev/null +++ b/vllm_npu/patch/worker/patch_multimodal_merge.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
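# Conceptual sketch, not the vLLM kernel: a repetition penalty divides positive
# logits and multiplies negative logits for every token already seen in the
# prompt or the output, which is what the torch fallback used above computes.
# Values are made up for demonstration.
import torch

logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]])
seen = torch.tensor([[True, True, False, False]])   # prompt_mask | output_mask
penalties = torch.tensor([1.2]).unsqueeze(1)

penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
logits = torch.where(seen, penalized, logits)
# -> tensor([[ 1.6667, -1.2000,  0.5000,  0.0000]])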
+# This file is a part of the vllm-ascend project. + +import torch +import vllm +from vllm.model_executor.models.utils import (_embedding_count_expression, + _flatten_embeddings) +from vllm.multimodal import NestedTensors + + +def _merge_multimodal_embeddings( + inputs_embeds: torch.Tensor, + is_multimodal: torch.Tensor, + multimodal_embeddings: NestedTensors, +) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder tokens in + ``input_ids``. + + Note: + This updates ``inputs_embeds`` in place. + """ + flattened = _flatten_embeddings(multimodal_embeddings) + try: + inputs_embeds[is_multimodal] = flattened + except RuntimeError as e: + num_expected_tokens = is_multimodal.sum().item() + assert isinstance(num_expected_tokens, int) + + if flattened.shape[0] != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {flattened.shape[0]} " + f"multimodal tokens to {num_expected_tokens} placeholders" + ) from e + else: + raise ValueError("Error during masked scatter operation") from e + + return inputs_embeds + + +vllm.model_executor.models.utils._merge_multimodal_embeddings = _merge_multimodal_embeddings diff --git a/vllm_npu/patch/worker/patch_roberta.py b/vllm_npu/patch/worker/patch_roberta.py new file mode 100644 index 0000000..9c9f5e8 --- /dev/null +++ b/vllm_npu/patch/worker/patch_roberta.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
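# Illustrative sketch, not part of the patch: the merge above overwrites the
# placeholder rows of inputs_embeds in place with the flattened multimodal
# embeddings; the row counts must match exactly, otherwise the ValueError in
# the patched function is raised. Shapes below are made up.
import torch

hidden_size = 4
inputs_embeds = torch.zeros(5, hidden_size)
is_multimodal = torch.tensor([False, True, True, False, True])
mm_embeddings = torch.ones(3, hidden_size)     # one row per placeholder token

assert int(is_multimodal.sum()) == mm_embeddings.shape[0]
inputs_embeds[is_multimodal] = mm_embeddings   # in-place masked overwrite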
+# + +from typing import Optional + +import torch +from vllm.model_executor.models.roberta import ( + RobertaEmbedding, RobertaForSequenceClassification, + replace_roberta_positions) +from vllm.sequence import IntermediateTensors + +# aclgraph does not support shift operator for now +# TODO: revert me when aclgraph supports shift operator +TOKEN_TYPE_SHIFT = 30 +TOKEN_TYPE_MULTIPLIER = 1 << 30 +TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1 + + +def _encode_token_type_ids(input_ids: torch.Tensor, + token_type_ids: torch.Tensor) -> None: + # input_ids can be padded to the right + input_ids[:token_type_ids.shape[0]].bitwise_or_(token_type_ids * + TOKEN_TYPE_MULTIPLIER) + + +def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: + + token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER + + input_ids.bitwise_and_(TOKEN_MASK) + + return token_type_ids + + +def roberta_for_sequence_classification_forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, +) -> torch.Tensor: + replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) + if token_type_ids is not None: + assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) + return self.roberta(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + + +def roberta_embedding_forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, +) -> torch.Tensor: + + token_type_ids = _decode_token_type_ids(input_ids) + + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +RobertaEmbedding.forward = roberta_embedding_forward +RobertaForSequenceClassification.forward = roberta_for_sequence_classification_forward diff --git a/vllm_npu/patch/worker/patch_triton.py b/vllm_npu/patch/worker/patch_triton.py new file mode 100644 index 0000000..8ea72a4 --- /dev/null +++ b/vllm_npu/patch/worker/patch_triton.py @@ -0,0 +1,16 @@ +import vllm.model_executor.layers.fla.ops.chunk +import vllm.model_executor.layers.fla.ops.fused_recurrent +import vllm.model_executor.layers.fla.ops.layernorm_guard +import vllm.model_executor.layers.mamba.ops.causal_conv1d + +from vllm_npu.ops.casual_conv1d import (causal_conv1d_fn, + causal_conv1d_update_npu) +from vllm_npu.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule +from vllm_npu.ops.sigmoid_gating import \ + fused_recurrent_gated_delta_rule_fwd_kernel + +vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu +vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn +vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel +vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn +vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule diff --git a/vllm_npu/patch/worker/patch_weight_loader.py b/vllm_npu/patch/worker/patch_weight_loader.py new file mode 
100644 index 0000000..ec3da9d --- /dev/null +++ b/vllm_npu/patch/worker/patch_weight_loader.py @@ -0,0 +1,41 @@ +import torch +from torch.nn.parameter import Parameter +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import GiB_bytes + +logger = init_logger(__name__) + + +def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + # This method creates unquantized linear weights. + # The weights are not quantized, and they are not sharded. + # The amount of memory allocated for the weights is + # sum(output_partition_sizes) * input_size_per_partition. + try: + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + except torch.cuda.OutOfMemoryError as e: + logger.error("Failed to create unquantized linear weights: %s", e) + if torch.cuda.is_available(): + logger.debug("CUDA device: %s", torch.cuda.current_device()) + logger.debug("Allocated: %.2f GiB", + torch.cuda.memory_allocated() / GiB_bytes) + logger.debug("Reserved: %.2f GiB", + torch.cuda.memory_reserved() / GiB_bytes) + raise RuntimeError( + "Failed to create unquantized linear weights. " + "This may be caused by insufficient memory to allocate " + "the weight.") from e + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + +UnquantizedLinearMethod.create_weights = create_weights diff --git a/vllm_npu/platform.py b/vllm_npu/platform.py index 0c980df..320e44d 100644 --- a/vllm_npu/platform.py +++ b/vllm_npu/platform.py @@ -1,9 +1,19 @@ -""" -NPUPlatform — Ascend NPU platform implementation for vLLM. - -Implements the ``vllm.platforms.Platform`` interface so that vLLM can -transparently target Huawei Ascend NPU devices. -""" +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
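# Illustrative sketch, not part of the patch: the unquantized weight created
# above needs sum(output_partition_sizes) * input_size_per_partition elements,
# so its size in bytes can be estimated up front. Partition sizes below are
# made up for demonstration.
import torch

output_partition_sizes = [4096, 4096, 4096]
input_size_per_partition = 4096
params_dtype = torch.bfloat16

num_elements = sum(output_partition_sizes) * input_size_per_partition
num_bytes = num_elements * torch.finfo(params_dtype).bits // 8
print(f"weight needs ~{num_bytes / (1 << 30):.2f} GiB")   # ~0.09 GiB here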
+# import gc import os @@ -11,11 +21,20 @@ from datetime import timedelta from typing import TYPE_CHECKING, Optional, Tuple import torch +import vllm.envs as envs_vllm from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import PrefixStore -from vllm.logger import init_logger +from vllm.logger import logger from vllm.platforms import Platform, PlatformEnum +from vllm_npu.ascend_config import (check_ascend_config, get_ascend_config, + init_ascend_config) +from vllm_npu.torchair.utils import (check_torchair_cache_exist, + delete_torchair_cache_file) +from vllm_npu.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p, + is_vl_model, update_aclgraph_sizes, + update_default_aclgraph_sizes) + if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig from vllm.utils import FlexibleArgumentParser @@ -24,23 +43,42 @@ else: VllmConfig = None FlexibleArgumentParser = None -logger = init_logger(__name__) - class NPUPlatform(Platform): - """Out-of-tree platform for Huawei Ascend NPU.""" _enum = PlatformEnum.OOT device_name: str = "npu" device_type: str = "npu" - dispatch_key: str = "PrivateUse1" + simple_compile_backend: str = "eager" # Disable torch.compile() ray_device_key: str = "NPU" device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" - simple_compile_backend: str = "eager" # torch.compile not supported + dispatch_key: str = "PrivateUse1" - # ----------------------------------------------------------------- - # Device management - # ----------------------------------------------------------------- + supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD] + + def is_sleep_mode_available(self) -> bool: + return True + + @classmethod + def pre_register_and_update(cls, + parser: Optional[FlexibleArgumentParser] = None + ) -> None: + # Adapt the global patch here. + from vllm_npu.utils import adapt_patch + adapt_patch(is_global_patch=True) + + # For online serving, "ascend" quantization method is not a choice natively, + # so we need to add "ascend" quantization method to quantization methods list + # and the user can enable quantization using "vllm serve --quantization ascend". 
+ if parser is not None: + quant_action = parser._option_string_actions.get('--quantization') + if quant_action and hasattr(quant_action, + 'choices') and quant_action.choices: + if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: + quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) + + from vllm_npu.quantization.quant_config import \ + AscendQuantConfig # noqa: F401 @classmethod def get_device_capability(cls, device_id: int = 0): @@ -48,14 +86,11 @@ class NPUPlatform(Platform): @classmethod def get_device_name(cls, device_id: int = 0) -> str: - import torch_npu # noqa: F401 return torch.npu.get_device_name(device_id) @classmethod - def get_device_total_memory(cls, device_id: int = 0) -> int: - import torch_npu # noqa: F401 - _, total = torch.npu.mem_get_info(device_id) - return total + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True @classmethod def inference_mode(cls): @@ -63,54 +98,218 @@ class NPUPlatform(Platform): @classmethod def set_device(cls, device: torch.device): - import torch_npu # noqa: F401 torch.npu.set_device(device) @classmethod def empty_cache(cls): - import torch_npu # noqa: F401 torch.npu.empty_cache() @classmethod def synchronize(cls): - import torch_npu # noqa: F401 torch.npu.synchronize() @classmethod def mem_get_info(cls) -> Tuple[int, int]: - import torch_npu # noqa: F401 return torch.npu.mem_get_info() - @classmethod - def is_pin_memory_available(cls): - return True - - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True - - @classmethod - def get_current_memory_usage( - cls, - device: Optional[torch.types.Device] = None, - ) -> float: - import torch_npu # noqa: F401 - torch.npu.reset_peak_memory_stats(device) - return torch.npu.max_memory_allocated(device) - @classmethod def clear_npu_memory(cls): - import torch_npu # noqa: F401 gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() - def is_sleep_mode_available(self) -> bool: - return False + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + if not envs_vllm.VLLM_USE_V1: + raise ValueError("vLLM Ascend does not support V0 engine.") + # initialize ascend config from vllm additional_config + ascend_config = init_ascend_config(vllm_config) - # ----------------------------------------------------------------- - # Attention backend routing - # ----------------------------------------------------------------- + from vllm.config import CompilationLevel # noqa: E402 + compilation_config = vllm_config.compilation_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + cache_config = vllm_config.cache_config + ascend_scheduler_config = ascend_config.ascend_scheduler_config + + kv_cache_dtype = vllm_config.additional_config.get( + "kv_cache_dtype", None) + if kv_cache_dtype is not None: + vllm_config.cache_config.cache_dtype = kv_cache_dtype + elif model_config and hasattr(model_config.hf_config, "index_topk"): + vllm_config.cache_config.cache_dtype = str( + model_config.dtype).replace("torch.", "") + if model_config is None: + logger.warning("Model config is missing. 
This may indicate " + "that we are running a test case") + enforce_eager = False + else: + enforce_eager = getattr(model_config, "enforce_eager", False) + + check_ascend_config(vllm_config, enforce_eager) + from vllm.config.compilation import CUDAGraphMode + if enforce_eager: + logger.info("Compilation disabled, using eager mode by default") + compilation_config.level = CompilationLevel.NO_COMPILATION + + compilation_config.cudagraph_num_of_warmups = 1 + + if compilation_config.level not in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE + ]: + logger.warning( + "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE", + compilation_config.level) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is. + if ascend_config.torchair_graph_config.enabled: + logger.info( + "Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE" + ) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + # Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension + # mismatches or configuration inconsistencies when users reuse cached computation graphs. Though + # this will increase graph compilation duration, it significantly enhances robustness and decreases + # graph launching time during inference. + if check_torchair_cache_exist( + ) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: + logger.warning( + "Torchair cache folder is deleted here to prevent runtime issues caused by dimension " + "mismatches or configuration inconsistencies when users reuse cached computation graphs. " + "In order to decrease torchair graph compilation time, users can enable both use_cached_graph " + "and use_cached_kv_cache_bytes in torchair_graph_config.") + delete_torchair_cache_file() + + # set cudaprah sizes before extending `compilation_config.splitting_ops` + vllm_config._set_cudagraph_sizes() + # There are cases where default cudagraph_capture_sizes are not friendly + # to ascend ops && hardwares. We update these sizes here to improve + # default performance. + update_default_aclgraph_sizes(vllm_config) + # TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism + # is supported by vllm-ascend. + if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \ + enable_sp(vllm_config): + original_sizes = compilation_config.cudagraph_capture_sizes + sp_aclgraph_sizes = \ + vllm_config.update_sizes_for_sequence_parallelism(original_sizes) + assert sp_aclgraph_sizes, ( + f"cudagraph_capture_sizes {original_sizes} does not contain" + f"values that are multiples of tp_size " + f"{vllm_config.parallel_config.tensor_parallel_size}") + if len(sp_aclgraph_sizes) != len(original_sizes): + compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes + vllm_config.compilation_config.init_with_cudagraph_sizes( + sp_aclgraph_sizes) + + # TODO: Full graph is fully supported later, and the default value will be set to full graph. + if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + + if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: + compilation_config.level = CompilationLevel.NO_COMPILATION + elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: + logger.info( + "PIECEWISE compilation enabled on NPU. 
use_inductor not supported - " + "using only ACL Graph mode") + assert compilation_config.level == CompilationLevel.PIECEWISE, \ + "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE" + compilation_config.set_splitting_ops_for_v1() + compilation_config.use_inductor = False + compilation_config.splitting_ops.extend([ + "vllm.unified_ascend_attention_with_output", "vllm.mla_forward" + ]) + update_aclgraph_sizes(vllm_config) + elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + logger.info( + "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " + "using only ACL Graph mode") + compilation_config.use_inductor = False + warning_message = """\033[91m + ********************************************************************************** + * WARNING: You have enabled the *full graph* feature. + * This is an early experimental stage and may involve various unknown issues. + * A known problem is that capturing too many batch sizes can lead to OOM + * (Out of Memory) errors or inference hangs. If you encounter such issues, + * consider reducing `gpu_memory_utilization` or manually specifying a smaller + * batch size for graph capture. + * For more details, please refer to: + * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs + **********************************************************************************\033[0m + """ + logger.warning(warning_message) + else: + logger.info( + "%s cudagraph_mode is not support on NPU. falling back to NONE", + compilation_config.cudagraph_mode) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + compilation_config.level = CompilationLevel.NO_COMPILATION + + # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1 + # Then, we will have to discuss the error handling strategy and user experience + if compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \ + os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1": + raise ValueError( + "ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. " + "Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you " + "need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — " + "for example, check the plog files (default: $HOME/ascend/log/debug) " + "for more information about runtime errors.") + + if parallel_config and parallel_config.worker_cls == "auto": + # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. + os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv" + if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp: + parallel_config.worker_cls = "vllm_npu.torchair.torchair_worker.NPUTorchairWorker" + else: + parallel_config.worker_cls = "vllm_npu.worker.worker_v1.NPUWorker" + + if cache_config: + if cache_config.block_size is None: + cache_config.block_size = 128 + + if cache_config.enable_prefix_caching or \ + not ascend_scheduler_config.enabled or \ + getattr(ascend_scheduler_config, "enable_chunked_prefill", False): + logger.warning( + "If chunked prefill or prefix caching is enabled, block size must be set to 128." + ) + origin_block_size = cache_config.block_size + cache_config.block_size = 128 + # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups. 
+ if model_config and model_config.hf_config.model_type == "qwen3_next": + logger.warning( + "When running qwen3-next model, block_size needs to be restored to its original value." + ) + cache_config.block_size = origin_block_size + + # Activate custom ops for v1, except on 310P + if not is_310p(): + compilation_config.custom_ops = ["all"] + + # If ascend_scheduler_config is enabled, + # extents original scheduler_config to use AscendScheduler. + if ascend_config.ascend_scheduler_config.enabled: + from vllm_npu.core.schedule_config import AscendSchedulerConfig + ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config( + vllm_config.scheduler_config, + ascend_config.ascend_scheduler_config) + vllm_config.scheduler_config = ascend_scheduler_config + elif ascend_config.recompute_scheduler_enable: + from vllm_npu.core.recompute_schedule_config import \ + RecomputeSchedulerConfig + recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config( + vllm_config.scheduler_config) + vllm_config.scheduler_config = recompute_scheduler_config + + if is_vl_model(vllm_config): + if bool(int(os.getenv("vllm_npu_ENABLE_FLASHCOMM", '0'))) or \ + bool(int(os.getenv("vllm_npu_ENABLE_FLASHCOMM1", '0'))): + raise ValueError( + "Currently, VL models doesn't support " + "FLASHCOMM in vllm-ascend. We will fix this in the future. " + "Please set vllm_npu_ENABLE_FLASHCOMM1=0.") @classmethod def get_attn_backend_cls( @@ -125,16 +324,68 @@ class NPUPlatform(Platform): has_sink=False, use_sparse=False, ): - return "vllm_npu.attention.attention_v1.AscendAttentionBackend" + if not use_v1: + raise ValueError("vLLM Ascend does not support V0 engine.") - # ----------------------------------------------------------------- - # Distributed - # ----------------------------------------------------------------- + ascend_config = get_ascend_config() + + if use_mla and ascend_config.enable_shared_expert_dp: + if use_mla and not use_sparse: + return "vllm_npu.torchair.torchair_mla.AscendMLATorchairBackend" + if use_mla and use_sparse: + return "vllm_npu.torchair.torchair_sfa.AscendSFATorchairBackend" + + use_torchair = ascend_config.torchair_graph_config.enabled + # choose attention backend based on use_mla and use_torchair + backend_map = { + (True, False, True): + "vllm_npu.torchair.torchair_mla.AscendMLATorchairBackend", + (True, False, False): + "vllm_npu.attention.mla_v1.AscendMLABackend", + (False, False, True): + "vllm_npu.torchair.torchair_attention.AscendAttentionTorchairBackend", + (False, False, False): + "vllm_npu.attention.attention_v1.AscendAttentionBackend", + (True, True, False): + "vllm_npu.attention.sfa_v1.AscendSFABackend", + (True, True, True): + "vllm_npu.torchair.torchair_sfa.AscendSFATorchairBackend", + } + return backend_map[(use_mla, use_sparse, use_torchair)] + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm_npu.lora.punica_npu.PunicaWrapperNPU" + + @classmethod + def get_current_memory_usage(cls, + device: Optional[torch.types.Device] = None + ) -> float: + torch.npu.reset_peak_memory_stats(device) + return torch.npu.max_memory_allocated(device) @classmethod def get_device_communicator_cls(cls) -> str: return "vllm_npu.distributed.communicator.NPUCommunicator" + @classmethod + def is_pin_memory_available(cls): + return True + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + """Returns whether the current platform can support v1 for the supplied + model configuration. 
+ """ + return True + + @classmethod + def get_static_graph_wrapper_cls(cls) -> str: + """ + Get piecewise backend class for piecewise graph. + """ + return "vllm_npu.compilation.acl_graph.ACLGraphWrapper" # noqa + @classmethod def stateless_init_device_torch_dist_pg( cls, @@ -144,13 +395,10 @@ class NPUPlatform(Platform): group_size: int, timeout: timedelta, ) -> ProcessGroup: - """Create an HCCL-based process group for NPU distributed.""" from torch.distributed import is_hccl_available from torch_npu._C._distributed_c10d import ProcessGroupHCCL - assert is_hccl_available(), ( - "HCCL is not available. Make sure torch_npu is properly installed." - ) + assert is_hccl_available() pg: ProcessGroup = ProcessGroup( prefix_store, @@ -161,61 +409,23 @@ class NPUPlatform(Platform): backend_options = ProcessGroupHCCL.Options() backend_options._timeout = timeout - backend_class = ProcessGroupHCCL( - prefix_store, group_rank, group_size, backend_options - ) + backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, + backend_options) device = torch.device("npu") + # TODO(Yizhou): Like we mentioned above, _set_default_backend is not + # implemented in the 2.5.1 version of PyTorch. But we need to set it + # after the latest version is released. + # pg._set_default_backend(backend_type) backend_class._set_sequence_number_for_group() backend_type = ProcessGroup.BackendType.CUSTOM pg._register_backend(device, backend_type, backend_class) return pg - # ----------------------------------------------------------------- - # Configuration - # ----------------------------------------------------------------- - @classmethod - def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: - """Adapt vLLM configuration for NPU hardware.""" - from vllm.config import CompilationLevel - - # Register NPU custom ops (must happen after platform is detected) - from vllm_npu import register_npu_ops - register_npu_ops() - - parallel_config = vllm_config.parallel_config - cache_config = vllm_config.cache_config - compilation_config = vllm_config.compilation_config - - # Set worker class - if parallel_config and parallel_config.worker_cls == "auto": - parallel_config.worker_cls = ( - "vllm_npu.worker.worker_v1.NPUWorker" - ) - - # Set default block size for NPU (aligned to 128) - if cache_config and cache_config.block_size is None: - cache_config.block_size = 128 - - # Disable torch.compile on NPU — use eager mode - compilation_config.level = CompilationLevel.NO_COMPILATION - - logger.info( - "NPUPlatform: configuration updated — " - "worker_cls=%s, block_size=%s, compilation=NO_COMPILATION", - getattr(parallel_config, "worker_cls", "N/A"), - getattr(cache_config, "block_size", "N/A"), - ) - - @classmethod - def supports_v1(cls, model_config: "ModelConfig") -> bool: + def support_hybrid_kv_cache(cls) -> bool: return True - @classmethod - def support_hybrid_kv_cache(cls) -> bool: - return False - @classmethod def support_static_graph_mode(cls) -> bool: - return False + return True diff --git a/vllm_npu/quantization/__init__.py b/vllm_npu/quantization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/quantization/quant_config.py b/vllm_npu/quantization/quant_config.py new file mode 100644 index 0000000..ab10596 --- /dev/null +++ b/vllm_npu/quantization/quant_config.py @@ -0,0 +1,488 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+from types import MappingProxyType
+from typing import Any, Callable, Dict, List, Mapping, Optional
+
+import torch
+from vllm.config import get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                  FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
+from vllm.model_executor.parameter import PerTensorScaleParameter
+from vllm.model_executor.utils import set_weight_attrs
+
+from vllm_npu.distributed.parallel_state import (get_mlp_tp_group,
+                                                 get_otp_group)
+from vllm_npu.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_npu.ops.linear import AscendUnquantizedLinearMethod
+from vllm_npu.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
+                            oproj_tp_enable)
+
+from .utils import get_quant_method
+
+
+@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
+class AscendQuantConfig(QuantizationConfig):
+    """Config class for Ascend quantization.
+
+    This is a general class that parses the quantization configs
+    supported on Ascend hardware.
+    """
+
+    def __init__(self, quant_config: Dict[str, Any]):
+        super().__init__()
+        self.quant_description = quant_config
+        # TODO(whx): remove this adaptation after adding "shared_head"
+        # to prefix of DeepSeekShareHead in vLLM.
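+        # For example, a (hypothetical) description key such as
+        # "model.layers.61.shared_head.head.weight" is duplicated under
+        # "model.layers.61.head.weight", so lookups work with either prefix.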
+ extra_quant_dict = {} + for k in self.quant_description.keys(): + if "shared_head" in k: + new_k = k.replace(".shared_head.", ".") + extra_quant_dict[new_k] = self.quant_description[k] + self.quant_description.update(extra_quant_dict) + + def __repr__(self) -> str: + return "AscendQuantConfig:\n" + super().__repr__() + + @classmethod + def get_name(cls) -> str: + return ASCEND_QUANTIZATION_METHOD + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.int8, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "Ascend hardware dose not support \"get_min_capability\" feature.") + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quant_model_description.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig": + return cls(config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + if torch.npu.is_available(): + return ASCEND_QUANTIZATION_METHOD + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + vllm_config = get_current_vllm_config() + model_type = vllm_config.model_config.hf_config.model_type + if model_type in packed_modules_model_mapping: + self.packed_modules_mapping = packed_modules_model_mapping[ + model_type] + from vllm.attention.layer import Attention + if prefix.startswith("language_model"): + prefix = prefix.split('.', 1)[-1] + if isinstance(layer, LinearBase): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return AscendUnquantizedLinearMethod() + return AscendLinearMethod(self, prefix, + self.packed_modules_mapping) + elif isinstance(layer, Attention) and \ + 'fa_quant_type' in self.quant_description.keys() and \ + self.quant_description['fa_quant_type'] is not None: + return AscendKVCacheMethod(self, prefix) + elif isinstance(layer, Attention) and self.quant_description.get( + 'kv_quant_type') == 'C8': + return AscendKVCacheMethod(self, prefix) + elif isinstance(layer, FusedMoE): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return AscendUnquantizedFusedMoEMethod(layer.moe_config) + return AscendFusedMoEMethod(self, prefix, + self.packed_modules_mapping) + elif isinstance(layer, VocabParallelEmbedding): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return UnquantizedEmbeddingMethod() + return AscendEmbeddingMethod(self, prefix, + self.packed_modules_mapping) + return None + + def is_layer_skipped_ascend( + self, + prefix: str, + fused_mapping: Mapping[str, List[str]] = MappingProxyType({})): + # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped + proj_name = prefix.split(".")[-1] + if proj_name in fused_mapping: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in fused_mapping[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = self.quant_description[shard_prefix + + '.weight'] == "FLOAT" + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. 
All shards of fused layers " + "to have the same precision.") + elif "experts" in prefix: + # For the experts' prefix (e.g., "model.layers.3.mlp.experts") + # Assume all experts within the same MLP use the same quantization method + experts_quant_description = [ + self.quant_description[layer] + for layer in self.quant_description if prefix in layer + ] + is_skipped = any(quantization == "FLOAT" + for quantization in experts_quant_description) + else: + is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT" + + assert is_skipped is not None + return is_skipped + + def get_scaled_act_names(self) -> List[str]: + return [] + + +packed_modules_model_mapping = { + "qwen3_moe": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + }, + "deepseek_v2": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + "deepseek_v3": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + "kimi_k2": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + "deepseek_v32": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; + # NOTE 2.The description file generated by the current msmodelslim tool does not have + # MTP layer info. Please manually add it and set the value to FLOAT. + "deepseek_mtp": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + "qwen3_next": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + "in_proj": ["in_proj_qkvz", "in_proj_ba"], + }, + "qwen2_5_vl": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + }, + "glm4_moe": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, +} + + +class AscendLinearMethod(LinearMethodBase): + """Linear method for Ascend quantization. + + Args: + quant_config: The Ascend quantization config. 
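+        prefix: Name of the layer, used to look up its quantization type in
+            the quant description.
+        packed_modules_mapping: Mapping from fused module names to the names
+            of the shards they were merged from.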
+ """ + + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]) -> None: + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "linear", + packed_modules_mapping) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + weight_dict = self.quant_method.get_weight(input_size_per_partition, + output_size_per_partition, + params_dtype) + + # Extract packing information (if present) + packed_dim = weight_dict.pop("_packed_dim", None) + packed_factor = weight_dict.pop("_packed_factor", None) + + for weight_name, weight_param in weight_dict.items(): + param = torch.nn.Parameter(weight_param, requires_grad=False) + set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) + + # Set packing attributes if the weight is packed + if packed_dim is not None and packed_factor is not None: + set_weight_attrs(param, { + "packed_dim": packed_dim, + "packed_factor": packed_factor + }) + + layer.register_parameter(weight_name, param) + set_weight_attrs(param, extra_weight_attrs) + + pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) + for pertensor_name, pertensor_param in pertensor_dict.items(): + param = PerTensorScaleParameter(data=pertensor_param, + weight_loader=weight_loader) + # disable warning + param.ignore_warning = True + layer.register_parameter(pertensor_name, param) + param.weight_loader = extra_weight_attrs.get("weight_loader") + + perchannel_dict = self.quant_method.get_perchannel_param( + output_size_per_partition, params_dtype) + for perchannel_name, perchannel_param in perchannel_dict.items(): + param = torch.nn.Parameter(perchannel_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(perchannel_name, param) + set_weight_attrs(param, extra_weight_attrs) + + # NOTE: In w4a8 quantization implementation, + # for down_proj and o_proj scale_bias shape is [output_size, 16], + # others are [output_size, 1] + layer_type = "row" if isinstance(layer, + RowParallelLinear) else "others" + + pergroup_dict = self.quant_method.get_pergroup_param( + input_size_per_partition, + output_size_per_partition, + params_dtype, + layer_type=layer_type) + for pergroup_name, pergroup_param in pergroup_dict.items(): + param = torch.nn.Parameter(pergroup_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(pergroup_name, param) + set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name: + setattr(param, "input_dim", 1) + param.input_dim = 1 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(layer, RowParallelLinear): + if layer.prefix.find("o_proj") != -1 and oproj_tp_enable(): + tp_rank = get_otp_group().rank_in_group + elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): + tp_rank = get_mlp_tp_group().rank_in_group + else: + tp_rank = get_tensor_model_parallel_rank() 
+ else: + tp_rank = 0 + return self.quant_method.apply(layer, x, bias, tp_rank) + + +class AscendKVCacheMethod(BaseKVCacheMethod): + """KVCache method for Ascend quantization. + + Args: + quant_config: The Ascend quantization config. + """ + + def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None: + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "attention") + + def create_weights(self, layer: torch.nn.Module) -> None: + # Different from linear method, there are no weight processing/slicing + # steps for attention in vllm. So the whole process of create weights + # is hidden into the specific quant method. + self.quant_method.create_weights(layer) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) + + def apply(self, layer: torch.nn.Module, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, + attn_type, scale, output) -> torch.Tensor: + return self.quant_method.apply(layer, query, key, value, kv_cache, + attn_metadata, attn_type, scale, output) + + +class AscendFusedMoEMethod(FusedMoEMethodBase): + """FusedMoE method for Ascend quantization. + + Args: + quant_config: The Ascend quantization config. + """ + + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]): + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "moe", + packed_modules_mapping) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + weight_param = self.quant_method.get_weight( + num_experts, intermediate_size_per_partition, hidden_size, + params_dtype) + for param_key, param_value in weight_param.items(): + param = torch.nn.Parameter(param_value, requires_grad=False) + layer.register_parameter(param_key, param) + set_weight_attrs(param, extra_weight_attrs) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + per_group_param = [ + "weight_scale_second", "weight_offset_second", "scale_bias" + ] + dynamic_quant_param = self.quant_method.get_dynamic_quant_param( + num_experts, intermediate_size_per_partition, hidden_size, + params_dtype) + for param_key, param_value in dynamic_quant_param.items(): + param = torch.nn.Parameter(param_value, requires_grad=False) + layer.register_parameter(param_key, param) + set_weight_attrs(param, extra_weight_attrs) + if any(fields in param_key for fields in per_group_param): + setattr(param, "quant_method", + FusedMoeWeightScaleSupported.GROUP.value) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor = None, + global_redundant_expert_num=0, + **kwargs, + ) -> torch.Tensor: + return self.quant_method.apply( + layer, x, router_logits, top_k, renormalize, use_grouped_topk, + 
global_num_experts, expert_map, topk_group, num_expert_group, + custom_routing_function, scoring_func, e_score_correction_bias, + is_prefill, enable_force_load_balance, log2phy, + global_redundant_expert_num, **kwargs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) + + def get_fused_moe_quant_config(self, layer: torch.nn.Module): + # TODO: implement this function + pass + + +class AscendEmbeddingMethod(AscendLinearMethod): + """Embedding method for Ascend quantization. + + Args: + quant_config: The Ascend quantization config. + """ + + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]) -> None: + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "linear", + packed_modules_mapping) diff --git a/vllm_npu/quantization/utils.py b/vllm_npu/quantization/utils.py new file mode 100644 index 0000000..0fb156a --- /dev/null +++ b/vllm_npu/quantization/utils.py @@ -0,0 +1,98 @@ +from typing import Any, Dict, Optional, Type + +from vllm.logger import logger + +from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod +from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) +from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, + AscendW8A8LinearMethod) +from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, + AscendW8A8DynamicLinearMethod) + +ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { + "W4A8_DYNAMIC": { + "linear": AscendW4A8DynamicLinearMethod, + "moe": AscendW4A8DynamicFusedMoEMethod, + }, + "W4A4_FLATQUANT_DYNAMIC": { + "linear": AscendW4A4FlatQuantDynamicLinearMethod, + }, + "W8A8": { + "linear": AscendW8A8LinearMethod, + "moe": AscendW8A8FusedMoEMethod, + "attention": AscendC8KVCacheMethod, + }, + "W8A8_DYNAMIC": { + "linear": AscendW8A8DynamicLinearMethod, + "moe": AscendW8A8DynamicFusedMoEMethod, + }, + "C8": { + "attention": AscendC8KVCacheMethod, + }, +} + + +def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, + packed_modules_mapping: Dict[str, Any]): + proj_name = prefix.split(".")[-1] + if proj_name in packed_modules_mapping: + quant_type = None + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in packed_modules_mapping[proj_name] + ] + for shard_prefix in shard_prefixes: + shard_quant_type = quant_description[shard_prefix + '.weight'] + + if quant_type is None: + quant_type = shard_quant_type + elif shard_quant_type != quant_type: + raise ValueError( + f"Not all shards of {prefix} are quantized with same quant type." + f"Shard {proj_name} uses {shard_quant_type}, but another shard" + f"use {quant_type}. Please check quantization config.") + elif "experts" in prefix: + # For the experts' prefix (e.g., "model.layers.3.mlp.experts") + # Assume all experts within the same MLP use the same quantization method + experts_quant_description = set(quant_description[layer] + for layer in quant_description + if prefix in layer) + if not len(experts_quant_description) == 1: + raise RuntimeError( + f"{prefix} has different quantization type: {experts_quant_description}." 
+ ) + quant_type = experts_quant_description.pop() + else: + quant_type = quant_description[prefix + '.weight'] + return quant_type + + +def get_quant_method(quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None): + logger.info_once("Using the vLLM Ascend Quantization now!") + if packed_modules_mapping is None: + packed_modules_mapping = dict() + # Attention + if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): + quant_type = quant_description['fa_quant_type'] + # Use KVCache int8 + elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): + quant_type = quant_description['kv_quant_type'] + # Linear + else: + quant_type = get_linear_quant_type(quant_description, prefix, + packed_modules_mapping) + if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys(): + method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type] + if layer_type in method_map.keys(): + method_cls = method_map[layer_type] + return method_cls() + else: + raise NotImplementedError( + f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." + ) + raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ + f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}") diff --git a/vllm_npu/quantization/w4a4_flatquant_dynamic.py b/vllm_npu/quantization/w4a4_flatquant_dynamic.py new file mode 100644 index 0000000..326980f --- /dev/null +++ b/vllm_npu/quantization/w4a4_flatquant_dynamic.py @@ -0,0 +1,193 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import math +from typing import Any, Dict, Optional, Tuple + +import torch +import torch_npu + +KRONECKER_QUANT_MAX_BATCH_SIZE = 32768 + + +def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor: + original_device = weight_tensor.device + weight_tensor_npu = weight_tensor.npu() + weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack( + weight_tensor_npu.to(torch.int32), inner_k_tiles=1) + return weight_int4_packed.to(original_device) + + +def get_decompose_dim(n): + a = int(math.sqrt(n)) + if a * a < n: + a += 1 + while True: + tmp = a * a - n + b = int(math.sqrt(tmp)) + if b * b == tmp: + break + a += 1 + return a - b, a + b + + +# TODO: This function is a temporary workaround for the npu_kronecker_quant operator, +# which has a limitation on the maximum batch size (dim0). This wrapper should be +# removed once the operator supports larger inputs natively. 
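+# Sketch of the behaviour (shapes are illustrative): an input of shape
+# [N, left_dim, right_dim] with N > KRONECKER_QUANT_MAX_BATCH_SIZE is split
+# along dim 0 into chunks of at most that size, each chunk is quantized
+# independently, and the (packed int4 values, per-token scales) results are
+# concatenated back along dim 0.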
+def batched_kronecker_quant( + x: torch.Tensor, + left_trans: torch.Tensor, + right_trans: torch.Tensor, + clip_ratio: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch_tokens = x.shape[0] + if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE: + return torch_npu.npu_kronecker_quant(x, + left_trans, + right_trans, + clip_ratio=clip_ratio, + dst_dtype=torch.int32) + x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0) + processed_chunks = [ + torch_npu.npu_kronecker_quant(chunk, + left_trans, + right_trans, + clip_ratio=clip_ratio, + dst_dtype=torch.int32) + for chunk in x_chunks + ] + quantized_list, scale_list = zip(*processed_chunks) + x_quantized_int4 = torch.cat(quantized_list, dim=0) + activation_scale = torch.cat(scale_list, dim=0) + return x_quantized_int4, activation_scale + + +class AscendW4A4FlatQuantDynamicLinearMethod: + """Linear method for Ascend W4A4_FLATQUANT_DYNAMIC. + + This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization. + - Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8 and packed to int32 during loading + - Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for distribution smoothing + - Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights + """ + input_size = 0 + + def __init__(self): + self.transpose_weight = False + self.sym = True + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size % 8 != 0: + raise ValueError( + f"input_size ({input_size}) must be divisible by 8 for int4 packing" + ) + AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + left_trans_dim, right_trans_dim = get_decompose_dim( + AscendW4A4FlatQuantDynamicLinearMethod.input_size) + params_dict["left_trans"] = torch.empty(left_trans_dim, + left_trans_dim, + dtype=params_dtype) + params_dict["right_trans"] = torch.empty(right_trans_dim, + right_trans_dim, + dtype=params_dtype) + params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32) + return params_dict + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=torch.float32) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=torch.float32) + return params_dict + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + return {} + + @staticmethod + def apply( + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + original_dtype = x.dtype + input_shape = x.shape + in_features = input_shape[-1] + left_dim = layer.left_trans.shape[0] + right_dim = layer.right_trans.shape[0] + if left_dim * right_dim != in_features: + raise ValueError( + f"FlatQuant transform matrices dimension mismatch: " + f"left_dim({left_dim}) * right_dim({right_dim}) != in_features({in_features})" + ) + left_trans_matched = layer.left_trans.to(original_dtype) + right_trans_matched = 
layer.right_trans.to(original_dtype) + x_reshaped = x.view(-1, left_dim, right_dim) + x_quantized_int4, activation_scale = batched_kronecker_quant( + x_reshaped, left_trans_matched, right_trans_matched, + layer.aclnn_clip_ratio) + x_quantized_reshaped = x_quantized_int4.view(-1, + left_dim * right_dim // 8) + pertoken_scale = activation_scale.view(-1).to(torch.float32) + output = torch_npu.npu_quant_matmul(x_quantized_reshaped, + layer.weight_packed.t(), + layer.weight_scale.view(-1).to( + torch.float32), + pertoken_scale=pertoken_scale, + bias=None, + output_dtype=original_dtype) + output = output.view(*input_shape[:-1], -1) + if bias is not None: + output = output + bias.to(original_dtype) + return output + + def process_weights_after_loading(self, layer): + weight_packed = pack_int4_weights(layer.weight.data) + if self.transpose_weight: + weight_packed = weight_packed.transpose(0, 1).contiguous() + layer.register_parameter( + 'weight_packed', + torch.nn.Parameter(weight_packed, requires_grad=False)) + del layer.weight + layer.weight_scale.data = layer.weight_scale.data.to(torch.float32) + layer.weight_offset.data = layer.weight_offset.data.to(torch.float32) + layer.left_trans = torch.nn.Parameter( + layer.left_trans.data.t().contiguous()) + layer.right_trans = torch.nn.Parameter(layer.right_trans.data) + layer.clip_ratio = torch.nn.Parameter( + layer.clip_ratio.data.to(torch.float32)) + layer.aclnn_clip_ratio = layer.clip_ratio.item() diff --git a/vllm_npu/quantization/w4a8_dynamic.py b/vllm_npu/quantization/w4a8_dynamic.py new file mode 100644 index 0000000..070e42c --- /dev/null +++ b/vllm_npu/quantization/w4a8_dynamic.py @@ -0,0 +1,490 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional + +import numpy as np +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.distributed.parallel_state import get_mc2_group +from vllm_npu.ops.moe.experts_selector import select_experts +from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ + + +class AscendW4A8DynamicLinearMethod: + """Linear method for Ascend W4A8_DYNAMIC + """ + + def __init__(self): + self.transpose_weight = True + + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + self.new_quant_version = quant_version == "1.0.0" + + from vllm.distributed import get_tensor_model_parallel_world_size + self.tp_size = get_tensor_model_parallel_world_size() + + def get_weight(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Create weight parameters. 
+ + For new quantization version (double int4 pack into int8), the output dimension + is compressed by factor 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned + dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader. + """ + params_dict = {} + + if self.new_quant_version: + # double int4 pack into int8: output dimension is compressed + pack_factor = 2 + actual_output_size = output_size // pack_factor + params_dict["weight"] = torch.empty(actual_output_size, + input_size, + dtype=torch.int8) + # Add packing information for vLLM's weight_loader + params_dict["_packed_dim"] = 0 + params_dict["_packed_factor"] = pack_factor + else: + params_dict["weight"] = torch.empty(output_size, + input_size, + dtype=torch.int8) + + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param(output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + """ + Create per-group quantization parameters. + """ + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_scale_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + params_dict["weight_offset_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + + # NOTE: In w4a8 quantization implementation, + # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16], + # others are [output_size, 1] + if self.new_quant_version: + scale_bias_dim = 16 if layer_type == "row" else 1 + + params_dict["scale_bias"] = torch.empty(output_size, + scale_bias_dim, + dtype=torch.float32) + return params_dict + + @staticmethod + def process_scale_second(weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False): + """ + Process the scale for second-level quantization. 
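+        The effective dequantization scale is scale * per_group_scale; only the
+        legacy (non-packed) format additionally derives a bias term from the
+        dequantized weights.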
+ + Args: + weight: weight tensor [k, n] (in new version, n is already compressed to n/2) + scale: first-level quantization scale [output_size] + per_group_scale: second-level per-group quantization scale [group_num, n_scale] + is_new_quant: whether it's the new quantization version (weight already compressed) + + Returns: + (antiquant_scale, bias): dequantization scale and bias (bias=None for new version) + """ + k, n = weight.shape + group_num, n_scale = per_group_scale.shape + + if is_new_quant: + # Restore logical dimension for compressed weight + n = n * 2 + + bias = None + if not is_new_quant: + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + # NOTE: scale_bias is not used currently + # because in msmodelslim w4a8 uses symmetric quantization + + # TODO: support potential future asymmetric quantization + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale.npu(), bias + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = None, + ) -> torch.Tensor: + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = layer.weight_scale.data.flatten().to( + torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_scale_second.data, scale_bias = self.process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, + ) + + if self.new_quant_version: + # Process the loaded data based on layer type + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + else: + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + + # Convert to NPU-specific int4pack format + if self.new_quant_version: + # weights on disk are already in packed int4 format + # pack 4 int8(int4*2) to int32 + assert layer.weight.data.shape[-1] % 4 == 0, \ + f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view( + torch.int32).contiguous() + else: + # weights are not compressed + # need to be packed via npu_convert_weight_to_int4pack + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) + + +class AscendW4A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W4A8_DYNAMIC. 
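+
+    Expert weights are 4-bit quantized (two int4 values packed into one int8
+    when the quant description reports version "1.0.0") and activations are
+    quantized dynamically at runtime.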
+ """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process + self.is_per_channel_weight = self.group_size == 0 + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + # NOTE: new quantize weights: 2 int4 pack into int8 + self.new_quant_version = quant_version == "1.0.0" + self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size + ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + if self.new_quant_version and self.tp_size > 16: + raise ValueError( + "The current weight does not support moe part tp>16.") + + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + def get_weight(self, num_experts: int, + intermediate_size_per_partition: int, hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + if self.new_quant_version: + w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_sizes // 2 + else: + w13_output_size = 2 * intermediate_size_per_partition + w2_output_size = hidden_sizes + + param_dict["w13_weight"] = torch.empty(num_experts, + w13_output_size, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + w2_output_size, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + def get_dynamic_quant_param(self, num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float32) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float32) + if not self.is_per_channel_weight: + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + + if self.new_quant_version: + param_dict["w13_scale_bias"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + param_dict["w2_scale_bias"] = torch.empty(num_experts, + hidden_sizes, + 16 // self.tp_size, + 
dtype=torch.float32) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like( + topk_ids, 0, global_num_experts - global_redundant_expert_num) + + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int4_w4a8=True, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + dynamic_eplb=self.dynamic_eplb) + + def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + scale = scale.transpose(1, 2).contiguous() + if self.is_per_channel_weight: + scale_np = scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype( + np.int64)).npu() + return scale_uint64_tensor, None + per_group_scale = per_group_scale.transpose(1, 2).contiguous() + group_num, k, n = weight.shape + # the weight of the new version is reduced by half by pack n, so it needs to be restored + if self.new_quant_version: + n = n * 2 + per_group_scale = per_group_scale.reshape(group_num, -1, n) + group_num, quantgroup_num, n = per_group_scale.shape + bias = None + if not self.new_quant_version: + weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ + per_group_scale.reshape([group_num, quantgroup_num, 1, n]) + weight_high = weight_high.reshape([group_num, k, n]) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) + scale_fp32 = (scale * per_group_scale).to(torch.float16).to( + torch.float32) + scale_fp32_np = 
scale_fp32.cpu().numpy() + scale_fp32_np.dtype = np.uint32 + sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), + dtype=np.uint32) + + sscale_uint64[..., ::2] = scale_fp32_np + + sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), + dtype=np.int64).copy() + sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape( + group_num, quantgroup_num, n) + sscale_uint64_tensor = sscale_uint64_tensor.npu() + return sscale_uint64_tensor, bias + + def update_bias(self, layer, w13_bias, w2_bias): + if self.new_quant_version: + layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + else: + w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + + def pack_to_int32(self, weight: torch.Tensor): + if self.new_quant_version: + # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 + assert weight.shape[ + -1] % 4 == 0, "the last dim of weight needs to be divided by 4" + return weight.view(torch.int32).contiguous() + else: + return torch_npu.npu_quantize(weight.to(torch.float32), + torch.tensor([1.]).npu(), None, + torch.quint4x2, -1, False) + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + + w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( + layer, "w13_weight_scale_second") else None + w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( + layer, "w2_weight_scale_second") else None + layer.w13_weight_scale.data, w13_bias = self.process_scale( + layer.w13_weight, layer.w13_weight_scale.data, + w13_weight_scale_second) + layer.w2_weight_scale.data, w2_bias = self.process_scale( + layer.w2_weight, layer.w2_weight_scale.data, + w2_weight_scale_second) + if hasattr(layer, "w13_weight_scale_second"): + # scale_second is no longer used, release this part of the memory + del layer.w13_weight_scale_second + del layer.w2_weight_scale_second + del layer.w13_weight_offset_second + del layer.w2_weight_offset_second + + self.update_bias(layer, w13_bias, w2_bias) + + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) + layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) diff --git a/vllm_npu/quantization/w8a8.py b/vllm_npu/quantization/w8a8.py new file mode 100644 index 0000000..04fa07f --- /dev/null +++ b/vllm_npu/quantization/w8a8.py @@ -0,0 +1,674 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional + +import torch +import torch_npu +from vllm.attention.backends.abstract import AttentionType +from vllm.distributed.parallel_state import get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.ops.moe.experts_selector import select_experts +from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz + + +def quant_per_tensor(in_tensor: torch.Tensor, + input_scale: torch.Tensor, + input_offset: torch.Tensor, + function=False): + return torch_npu.npu_quantize(in_tensor, input_scale, input_offset, + torch.qint8, -1, function) + + +class AscendW8A8LinearMethod: + """Linear method for Ascend W8A8. + + Args: + w_sym: whether the linear weight is symmetrically quantized. + """ + + def __init__(self) -> None: + # aclnn quant matmul requires to transpose matrix B, set to true by default. + self.transpose_weight = not is_310p() + + @staticmethod + def get_weight( + input_size: int, + output_size: int, + params_dtype: torch.dtype = torch.bfloat16, + ) -> Dict[str, Any]: + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + params_dict["input_scale"] = torch.empty(1, dtype=params_dtype) + params_dict["input_offset"] = torch.empty(1, dtype=torch.int8) + return params_dict + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + params_dict = {} + params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32) + if params_dtype == torch.bfloat16: + params_dict["deq_scale"] = torch.empty(output_size, + dtype=torch.float32) + elif params_dtype == torch.float16: + params_dict["deq_scale"] = torch.empty(output_size, + dtype=torch.int64) + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + return params_dict + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + return {} + + @staticmethod + def apply( + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + if x.dtype != torch.int8: + layer_cls_name = layer.__class__.__name__ + try: + weight_prefetch_method = get_forward_context( + ).weight_prefetch_method + except AssertionError: + weight_prefetch_method = None + + # prefetch qkvo_proj.weight preprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( + layer_cls_name=layer_cls_name, + weight=layer.weight, + start_flag=x, + ) + # quant + x = quant_per_tensor( + x, + layer.aclnn_input_scale_reciprocal, + layer.aclnn_input_offset, + ) + # prefetch qkvo_proj.weight postprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( 
+ layer_cls_name=layer_cls_name, + stop_flag=x, + ) + + quant_bias = layer.quant_bias if tp_rank == 0 else None + if is_310p(): + # On 300I Duo platform, we need transpose again if + # using nz. This transpose can be skipped in torchair. + output = torch_npu.npu_quant_matmul( + x, + layer.weight.data.transpose(1, 0), + layer.deq_scale, + bias=quant_bias, + output_dtype=layer.params_dtype, + ) + else: + output = torch_npu.npu_quant_matmul( + x, + layer.weight, + layer.deq_scale, + bias=quant_bias, + output_dtype=layer.params_dtype, + ) + return output + + def process_weights_after_loading(self, layer): + expanding_factor = layer.weight.data.shape[1] + layer.aclnn_input_scale = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor), + requires_grad=False) + layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor), + requires_grad=False) + layer.aclnn_input_offset = torch.nn.Parameter( + layer.input_offset.data.repeat(expanding_factor), + requires_grad=False).to(layer.aclnn_input_scale.dtype) + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + if is_enable_nz(): + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.weight_scale.data = torch.flatten(layer.weight_scale.data) + layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + + +class AscendW8A8FusedMoEMethod: + """FusedMoe method for Ascend W8A8. + """ + + def __init__(self): + self.transpose_weight = True + + @staticmethod + def get_weight(num_experts: int, intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty(num_experts, + 2 * + intermediate_size_per_partition, + hidden_sizes, + dtype=torch.int8, + requires_grad=False) + param_dict["w2_weight"] = torch.empty(num_experts, + hidden_sizes, + intermediate_size_per_partition, + dtype=torch.int8, + requires_grad=False) + return param_dict + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float16) + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float32) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float16) + param_dict["w2_deq_scale"] = torch.empty(num_experts, + hidden_sizes, + dtype=torch.float32) + param_dict["w13_deq_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32) + param_dict["w2_input_scale"] = torch.empty(num_experts, + 1, + dtype=torch.float32) + param_dict["w13_input_scale"] = torch.empty(num_experts, + 1, + dtype=torch.float32) + param_dict["w2_input_offset"] = torch.empty(num_experts, + 1, + dtype=torch.int8) + param_dict["w13_input_offset"] = torch.empty(num_experts, + 1, + dtype=torch.int8) + param_dict["quant_bias"] = torch.empty(num_experts, + hidden_sizes, + dtype=torch.int32) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = 
False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + if is_310p(): + return fused_experts_310p(hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w1_input_scale=layer.w13_input_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=global_num_experts, + expert_map=expert_map) + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w1_input_scale=layer.w13_input_scale, + w1_input_offset=layer.w13_input_offset, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, + w2_input_offset=layer.w2_input_offset, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=global_num_experts, + expert_map=expert_map) + + def process_weights_after_loading(self, layer): + if not is_310p(): + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( + layer.w13_weight_scale.data.shape[0], -1) + + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( + layer.w13_weight_offset.data.shape[0], -1) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( + layer.w2_weight_scale.data.shape[0], -1) + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( + layer.w2_weight_offset.data.shape[0], -1) + expanding_factor_w13 = layer.w13_weight.data.shape[1] + expanding_factor_w2 = layer.w2_weight.data.shape[1] + + if is_310p(): + layer.w13_input_scale.data = torch.nn.Parameter( + layer.w13_input_scale.data.max()) + layer.w2_input_scale.data = torch.nn.Parameter( + layer.w2_input_scale.data.max()) + else: + layer.w13_input_scale.data = torch.nn.Parameter( + layer.w13_input_scale.data.repeat(1, + expanding_factor_w13)[0:1]) + layer.w2_input_scale.data = torch.nn.Parameter( + layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]) + + layer.w13_input_offset.data = torch.nn.Parameter( + layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1]) + layer.w2_input_offset.data = torch.nn.Parameter( + layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]) + + # converting ACL_FORMAT_FRACTAL_NZ. + # npu_quant_grouped_matmul_dequant in eager mode does not accept + # ACL_FORMAT_FRACTAL_NZ. 
+ if not is_310p(): + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous() + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous() + + +class AscendC8KVCacheMethod: + + def __init__(self) -> None: + self.antiquant_scale_comb = None + + @staticmethod + def create_weights(layer) -> None: + param_dict = {} # num_kv_heads * head_size + param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads * + layer.head_size, + dtype=torch.float16, + requires_grad=False) + param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads * + layer.head_size, + dtype=torch.float16, + requires_grad=False) + for weight_name, weight_param in param_dict.items(): + param = torch.nn.Parameter(weight_param, requires_grad=False) + layer.register_parameter(weight_name, param) + + def process_weights_after_loading(self, layer): + self.antiquant_scale_comb = torch.cat( + (layer.key_antiquant_scale.data.unsqueeze(0), + layer.value_antiquant_scale.data.unsqueeze(0)), + dim=0).to(torch.float16).contiguous() + + def apply(self, layer, query, key, value, kv_cache, attn_metadata, + attn_type, scale, output) -> torch.Tensor: + num_tokens = query.shape[0] + if attn_metadata is None: + return output.view(num_tokens, layer.num_heads * layer.head_size) + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + # C8 + quant_key = quant_per_tensor( + key.view(-1, layer.num_kv_heads * layer.head_size), + layer.key_antiquant_scale.data.view(-1), None, True) + quant_value = quant_per_tensor( + value.view(-1, layer.num_kv_heads * layer.head_size), + layer.value_antiquant_scale.data.view(-1), None, True) + + # View q k v to BSH. + query = query.view(-1, layer.num_heads, layer.head_size) + key = key.view(-1, layer.num_kv_heads, layer.head_size) + value = value.view(-1, layer.num_kv_heads, layer.head_size) + # TODO: Remove this contiguous in the future. + value = value.contiguous() + + if kv_cache[0].numel() > 0: + # if key_cache is None: + key_cache, value_cache = kv_cache[0], kv_cache[1] + slots = attn_metadata.slot_mapping + + block_size = key_cache.shape[1] + slots_indices = slots.reshape(-1, 1) + block_indices = slots_indices // block_size + slots_indices = slots_indices % block_size + indices = torch.cat((block_indices, slots_indices), dim=1) + + # C8 + torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key) + torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value) + + # V0-Style scheduler situation. 
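# A small CPU sketch of the paged-cache write performed above: each flat slot id
# is split into (block index, offset inside the block) and the quantized K/V
# rows are scattered into the cache at those coordinates. torch_npu's
# npu_scatter_nd_update_ is replaced here by plain advanced indexing and the
# shapes are illustrative only.
import torch

num_blocks, block_size, kv_heads, head_size = 4, 8, 2, 16
key_cache = torch.zeros(num_blocks, block_size, kv_heads * head_size,
                        dtype=torch.int8)

quant_key = torch.randint(-128, 127, (5, kv_heads * head_size), dtype=torch.int8)
slots = torch.tensor([3, 9, 10, 17, 31])      # hypothetical slot mapping

block_indices = slots // block_size           # which block each token lands in
block_offsets = slots % block_size            # position inside that block
key_cache[block_indices, block_offsets] = quant_key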
+ if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=scale, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + out=output.reshape(query.shape)) + + elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + raise NotImplementedError("kv cache int8 are not " + "implemented for " + "PrefillCacheHit") + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly + if hasattr(attn_metadata, "decode"): + # torch_air + decode_meta = attn_metadata.decode + seq_lens = decode_meta.seq_lens_list + else: + seq_lens = attn_metadata.seq_lens + block_size = key_cache.shape[1] + query = query.view(num_tokens, 1, layer.num_heads * + layer.head_size).contiguous() # changed + + # [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D] + key = key_cache + value = value_cache + + output = torch_npu.npu_incre_flash_attention( + query, + key, + value, + num_key_value_heads=layer.num_kv_heads, + num_heads=layer.num_heads, + actual_seq_lengths=seq_lens, + scale_value=scale, + input_layout='BSH', + block_size=block_size, + block_table=attn_metadata.block_tables, + antiquant_scale=self.antiquant_scale_comb, + ) + + # Normal V1 situation. + else: + raise NotImplementedError("kv cache int8 are not " + "implemented for " + "other case") + return output + + +def fused_experts_310p( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w1_input_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + global_num_experts: int, + expert_map: torch.Tensor = None, +) -> torch.Tensor: + ep_size = get_ep_group().world_size + local_num_experts = global_num_experts // ep_size + local_num_group = top_k // ep_size + + bsz, _ = hidden_states.shape + flatten_topk_ids = topk_ids.view(-1) + sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) + sorted_topk_ids = sorted_topk_ids.to(torch.int32) + sorted_hidden_states = hidden_states.index_select( + 0, sorted_topk_ids // local_num_group) + + experts_id = torch.arange(0, + local_num_experts, + dtype=topk_ids.dtype, + device=topk_ids.device) + num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( + torch.float32).sum(0) + topk_scales = topk_weights.view(-1).index_select( + 0, sorted_topk_ids).unsqueeze(-1) + group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) + + gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant( + x=sorted_hidden_states, + quantized_weight=w1, + weight_scale=w1_scale, + group_list=group_list, + x_scale=w1_input_scale, + quant_mode="pertensor") + + gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( + torch.float16) + gate_up_out *= topk_scales + + down_out = torch_npu.npu_quant_grouped_matmul_dequant( + x=gate_up_out, + quantized_weight=w2, + weight_scale=w2_scale, + group_list=group_list, + x_scale=w2_input_scale, + quant_mode="pertensor") + + unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) + unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids) + final_hidden_states = unsorted_hidden_states.reshape( + bsz, top_k // ep_size, -1).sum(1) + + return 
final_hidden_states + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w1_input_scale: torch.Tensor, + w1_input_offset: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + w2_input_offset: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + global_num_experts: int, + expert_map: torch.Tensor = None, +) -> torch.Tensor: + """ + Fused experts with top-k routing. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). + w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + top_k: Number of experts to select. + expert_map: Expert mapping of shape (num_experts,). + + Returns: + hidden_states: Hidden states after routing. + """ + """ + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + """ + + original_dtype = hidden_states.dtype + ep_size = get_ep_group().world_size + local_num_experts = global_num_experts // ep_size + w1_input_scale, _ = w1_input_scale.max(0) + quant_sorted_hidden_states = quant_per_tensor( + hidden_states, + w1_input_scale, + None, + True, + ) + if expert_map is not None: + expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2( + quant_sorted_hidden_states, + topk_ids, + scale=None, + active_num=topk_ids.numel(), + expert_capacity=-1, + expert_num=local_num_experts, + drop_pad_mode=0, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + quant_mode=-1, + active_expert_range=[0, local_num_experts], + row_idx_type=0, + ) + + else: + raise NotImplementedError( + "The quantified version of MOE class models " + "currently does not support tensor parallelism") + if expanded_x.dtype != w1.dtype: + w1_input_scale, _ = w1_input_scale.max(0) + quant_sorted_hidden_states = quant_per_tensor( + expanded_x, + w1_input_scale, + None, + True, + ) + else: + quant_sorted_hidden_states = expanded_x + gate_up_out = torch_npu.npu_grouped_matmul( + x=[quant_sorted_hidden_states], + weight=[w1], + scale=[w1_scale * w1_input_scale[0]], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_count, + output_dtype=original_dtype, + )[0] + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + + if gate_up_out.dtype != w2.dtype: + w2_input_scale, _ = w2_input_scale.max(0) + quant_gate_up_out = quant_per_tensor( + gate_up_out, + w2_input_scale, + None, + True, + ) + else: + quant_gate_up_out = gate_up_out + + down_out = torch_npu.npu_grouped_matmul( + x=[quant_gate_up_out], + weight=[w2], + scale=[w2_scale * w2_input_scale[0]], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_count, + output_dtype=original_dtype, + )[0] + + if expert_map is not None: + final_hidden_states = torch_npu.npu_moe_finalize_routing( + down_out, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights.to(down_out.dtype), + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + drop_pad_mode=2, + ) + else: + raise 
NotImplementedError( + "The quantified version of MOE class models " + "currently does not support tensor parallelism") + + return final_hidden_states diff --git a/vllm_npu/quantization/w8a8_dynamic.py b/vllm_npu/quantization/w8a8_dynamic.py new file mode 100644 index 0000000..dfec0ec --- /dev/null +++ b/vllm_npu/quantization/w8a8_dynamic.py @@ -0,0 +1,284 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +import torch_npu +from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.distributed.parallel_state import get_mc2_group +from vllm_npu.ops.moe.experts_selector import select_experts +from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz + + +class AscendW8A8DynamicLinearMethod: + """Linear method for Ascend W8A8_DYNAMIC. + """ + + def __init__(self): + self.transpose_weight = True + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + return params_dict + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + return {} + + @staticmethod + def apply( + layer: torch.nn.Module, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + config = getattr(layer, "_ascend_quant_config", {}) + if not isinstance(x, tuple): + output_dtype = config.get("output_dtype", x.dtype) + quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + else: + assert "output_dtype" in config.keys(), ( + f"DynamicLinearMethod needs explicitly specified `output_dtype`" + f"for pre-quantized input, got config [{config}]") + output_dtype = config["output_dtype"] + quantized_x, dynamic_scale = x + pertoken_scale = (dynamic_scale + if config.get("pertoken_scale", True) else None) + + output = torch_npu.npu_quant_matmul( + quantized_x, + layer.weight, + layer.weight_scale, + pertoken_scale=pertoken_scale, + bias=bias, + output_dtype=output_dtype, + ) + return ((output, dynamic_scale) + if config.get("return_scale", False) else output) + + def 
process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + # cast quantized weight tensors in NZ format for higher inference speed + if is_enable_nz(): + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + + +class AscendW8A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W8A8_DYNAMIC. + """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + vllm_config = get_current_vllm_config() + ascend_config = get_ascend_config() + self.use_aclgraph = ( + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and not vllm_config.model_config.enforce_eager + and not ascend_config.torchair_graph_config.enabled) + self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + @staticmethod + def get_weight(num_experts: int, intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty(num_experts, + 2 * + intermediate_size_per_partition, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + hidden_sizes, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + + topk_weights, topk_ids = select_experts( + 
hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like( + topk_ids, 0, global_num_experts - global_redundant_expert_num) + + if self.use_aclgraph: + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + expert_map=expert_map, + dynamic_eplb=self.dynamic_eplb, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num) + + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale.to(torch.float32), + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int8_w8a8=True, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + dynamic_eplb=self.dynamic_eplb) + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ) + torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) + layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( + layer.w13_weight_scale.data.shape[0], -1) + layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( + torch.float32) + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( + layer.w13_weight_offset.data.shape[0], -1) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( + layer.w2_weight_scale.data.shape[0], -1) + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( + layer.w2_weight_offset.data.shape[0], -1) diff --git a/vllm_npu/sample/__init__.py b/vllm_npu/sample/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/sample/logits_processor/__init__.py b/vllm_npu/sample/logits_processor/__init__.py new file mode 100644 index 0000000..0dd2058 --- /dev/null +++ b/vllm_npu/sample/logits_processor/__init__.py @@ -0,0 +1,50 @@ +import itertools +from collections.abc import Sequence +from typing import TYPE_CHECKING, Union + +import torch +from vllm.logger import init_logger +from vllm.v1.sample import logits_processor +from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, + MinTokensLogitsProcessor) +from vllm.v1.sample.logits_processor.interface import LogitsProcessor +from vllm.v1.sample.logits_processor.state import LogitsProcessors + +from vllm_npu.sample.logits_processor.builtin import \ + AscendMinPLogitsProcessor + 
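# For orientation, the min-p rule that MinPLogitsProcessor (and its Ascend
# subclass below) enforces, reduced to one function: tokens whose probability
# falls below min_p times the most likely token's probability are masked out.
# The Ascend variant appears to change only how the per-request min_p values
# are staged on device, not this rule.
import torch


def min_p_filter_ref(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1)
    threshold = probs.amax(dim=-1, keepdim=True) * min_p.unsqueeze(-1)
    return logits.masked_fill(probs < threshold, float("-inf"))


example_logits = torch.randn(2, 32)                 # [num_reqs, vocab_size]
filtered = min_p_filter_ref(example_logits, torch.tensor([0.05, 0.10]))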
+if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + +# Error message when the user tries to initialize vLLM with a pooling model +# and custom logitsproces +STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" + " logits processors.") + +BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ + MinTokensLogitsProcessor, + LogitBiasLogitsProcessor, + AscendMinPLogitsProcessor, +] + + +def build_logitsprocs( + vllm_config: "VllmConfig", + device: torch.device, + is_pin_memory: bool, + is_pooling_model: bool, + custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), +) -> LogitsProcessors: + if is_pooling_model: + if custom_logitsprocs: + raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) + logger.debug("Skipping logits processor loading because pooling models" + " do not support logits processors.") + return LogitsProcessors() + custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs( + custom_logitsprocs) + return LogitsProcessors( + ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) diff --git a/vllm_npu/sample/logits_processor/builtin.py b/vllm_npu/sample/logits_processor/builtin.py new file mode 100644 index 0000000..f38d940 --- /dev/null +++ b/vllm_npu/sample/logits_processor/builtin.py @@ -0,0 +1,35 @@ +import torch +from vllm.config import VllmConfig +from vllm.v1.sample.logits_processor import MinPLogitsProcessor + + +class AscendMinPLogitsProcessor(MinPLogitsProcessor): + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + super().__init__(vllm_config, device, is_pin_memory) + + decode_max_num_seqs = getattr(vllm_config.scheduler_config, + 'decode_max_num_seqs', 0) + if decode_max_num_seqs != 0: + max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs, + decode_max_num_seqs) + + self.min_p_count: int = 0 + + self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=is_pin_memory) + self.min_p_cpu = self.min_p_cpu_tensor.numpy() + + self.use_double_tensor = torch.device(device).type != "cpu" + + if self.use_double_tensor: + # Pre-allocated device tensor + self.min_p_device: torch.Tensor = torch.empty( + (max_num_reqs, ), dtype=torch.float32, device=device) + else: + self.min_p_device = self.min_p_cpu_tensor + # Current slice of the device tensor + self.min_p: torch.Tensor = self.min_p_device[:0] diff --git a/vllm_npu/sample/rejection_sampler.py b/vllm_npu/sample/rejection_sampler.py new file mode 100644 index 0000000..e0d770d --- /dev/null +++ b/vllm_npu/sample/rejection_sampler.py @@ -0,0 +1,504 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import torch.nn as nn +import vllm.v1.sample.rejection_sampler as rs +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import (RejectionSampler, compute_probs, + generate_uniform_probs) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +PLACEHOLDER_TOKEN_ID = -1 +GREEDY_TEMPERATURE = -1 +# Maximum number of speculative draft tokens allowed per request in a single +# step. This value is chosen to be large enough to handle typical use cases. +MAX_SPEC_LEN = 32 + + +class AscendRejectionSampler(RejectionSampler, nn.Module): + """ + The implementation strictly follows the algorithm described in + https://arxiv.org/abs/2211.17192. 
+ However, we want to clarify the terminology used in the implementation: + accepted tokens: tokens that are accepted based on the relationship + between the "raw" draft and target probabilities. + recovered tokens: tokens that are sampled based on the adjusted probability + distribution, which is derived from both the draft and target + probabilities. + bonus tokens: + If all proposed tokens are accepted, the bonus token is added to the + end of the sequence. The bonus token is only sampled from the target + probabilities. We pass in the bonus tokens instead of sampling them + in the rejection sampler to allow for more flexibility in the + sampling process. For example, we can use top_p, top_k sampling for + bonus tokens, while spec decode does not support these sampling + strategies. + output tokens: + Tokens are finally generated with the rejection sampler. + output tokens = accepted tokens + recovered tokens + bonus tokens + """ + + def forward( + self, + metadata: SpecDecodeMetadata, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_logits: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + ''' + Args: + metadata: + Metadata for spec decoding. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + target_logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. + NOTE: `target_logits` can be updated in place to save memory. + bonus_token_ids_tensor (torch.Tensor): + A tensor containing bonus tokens. Shape is [batch_size, 1]. + Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling + process such as top_p, top_k sampling. + sampling_metadata (SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + ''' + assert metadata.max_spec_len <= MAX_SPEC_LEN + # [num_tokens, vocab_size] + # NOTE(woosuk): `target_logits` can be updated in place inside the + # `compute_probs` function. 
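# The acceptance rule applied downstream, shown per position on toy numbers: a
# draft token is kept when target_p / draft_p >= u with u ~ U(0, 1); otherwise
# a recovered token is drawn from norm(max(target - draft, 0)). This is only a
# sketch of the rule that rejection_random_sample_pytorch and
# sample_recovered_tokens implement below, not the batched path itself.
import torch


def accept_or_recover(draft_p: torch.Tensor, target_p: torch.Tensor,
                      draft_token: int, u: float) -> int:
    if draft_p[draft_token] > 0 and target_p[draft_token] / draft_p[draft_token] >= u:
        return draft_token                        # accepted
    residual = (target_p - draft_p).clamp(min=0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1))    # recovered token


draft_p = torch.tensor([0.1, 0.6, 0.3])
target_p = torch.tensor([0.3, 0.2, 0.5])
token = accept_or_recover(draft_p, target_p, draft_token=1, u=0.5)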
+ target_probs = compute_probs( + target_logits, + metadata.cu_num_draft_tokens, + sampling_metadata, + ) + + output_token_ids = rejection_sample( + metadata.draft_token_ids, + metadata.num_draft_tokens, + metadata.max_spec_len, + metadata.cu_num_draft_tokens, + draft_probs, + target_probs, + bonus_token_ids, + sampling_metadata, + ) + return output_token_ids + + +def rejection_sample( + # [num_tokens] + draft_token_ids: torch.Tensor, + # [batch_size] + num_draft_tokens: list[int], + max_spec_len: int, + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + assert draft_token_ids.ndim == 1 + assert draft_probs is None or draft_probs.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + assert target_probs.ndim == 2 + + batch_size = len(num_draft_tokens) + num_tokens = draft_token_ids.shape[0] + vocab_size = target_probs.shape[-1] + device = target_probs.device + assert draft_token_ids.is_contiguous() + assert draft_probs is None or draft_probs.is_contiguous() + assert target_probs.is_contiguous() + assert bonus_token_ids.is_contiguous() + assert target_probs.shape == (num_tokens, vocab_size) + + # Create output buffer. + output_token_ids = torch.empty( + (batch_size, max_spec_len + 1), + dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. + device=device, + ) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + + if sampling_metadata.all_greedy: + is_greedy = None + else: + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE + if not sampling_metadata.all_random: + # Rejection sampling for greedy sampling requests. + target_argmax = target_probs.argmax(dim=-1) + if min(num_draft_tokens) == 1 and max( + num_draft_tokens) == 1 and sampling_metadata.all_greedy: + rejection_greedy_sample_spec_len_1_pytorch( + output_token_ids, + draft_token_ids, + target_argmax, + bonus_token_ids, + ) + else: + rejection_greedy_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + num_draft_tokens, + max_spec_len, + is_greedy, + ) + if sampling_metadata.all_greedy: + return output_token_ids + + # Generate uniform probabilities for rejection sampling. + # [num_tokens] + uniform_probs = generate_uniform_probs( + num_tokens, + num_draft_tokens, + sampling_metadata.generators, + device, + ) + + # Sample recovered tokens for each position. + # [num_tokens] + recovered_token_ids = sample_recovered_tokens( + max_spec_len, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device, + ) + + # Rejection sampling for random sampling requests. + rejection_random_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + # num_warps=1, + ) + return output_token_ids + + +def expand_batch_to_tokens( + x: torch.Tensor, # [batch_size] + cu_num_tokens: torch.Tensor, # [batch_size] + num_tokens: int, + replace_from: int = 0, + replace_to: int = 0, +) -> torch.Tensor: + """Expand [batch_size] tensor to [num_tokens] tensor based on the number of + tokens per batch in cu_num_tokens. 
+ + For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then + num_tokens = 6, and expanded_x = [a, a, b, b, b, c]. + + Args: + x: [batch_size] tensor to expand. + cu_num_tokens: [batch_size] tensor containing the cumulative number of + tokens per batch. Each element represents the total number of + tokens up to and including that batch. + num_tokens: Total number of tokens. + replace_from: int = 0 + Value to be replaced if it is found in x. + replace_to: int = 0 + Value to replace with when replace_from is found. + Returns: + expanded_x: [num_tokens] tensor. + """ + batch_size = x.shape[0] + assert cu_num_tokens.shape[0] == batch_size + expanded_x = x.new_empty(num_tokens) + expand_pytorch( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. + ) + return expanded_x + + +def sample_recovered_tokens( + max_spec_len: int, + num_draft_tokens: list[int], + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens] + draft_token_ids: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: + # NOTE(woosuk): Create only one distribution for each request. + batch_size = len(num_draft_tokens) + vocab_size = target_probs.shape[-1] + q = torch.empty( + (batch_size, vocab_size), + dtype=torch.float32, + device=device, + ) + q.exponential_() + for i, generator in sampling_metadata.generators.items(): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. + if num_draft_tokens[i] > 0: + q[i].exponential_(generator=generator) + + recovered_token_ids = torch.empty_like(draft_token_ids) + sample_recovered_tokens_pytorch( + recovered_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + q, + vocab_size, + IS_NGRAM=draft_probs is None, + ) + return recovered_token_ids + + +def rejection_greedy_sample_spec_len_1_pytorch( + output_token_ids, # [batch_size, 2] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] +): + batch_size = output_token_ids.size(0) + num_tokens = draft_token_ids.size(0) + assert batch_size == num_tokens + accept_req_mask = draft_token_ids == target_argmax + output_token_ids[:, 0] = target_argmax + bonus_token_ids = bonus_token_ids.squeeze(1) + output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask] + + +def rejection_greedy_sample_pytorch( + output_token_ids, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] + draft_tokens_per_req, # [batch_size], list + max_spec_len, + is_greedy=None, # [batch_size] or None +): + batch_size = output_token_ids.size(0) + num_tokens = draft_token_ids.size(0) + device = output_token_ids.device + draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to( + device, non_blocking=True) + if is_greedy is None: + is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device) + + start_indices = cu_num_draft_tokens - draft_tokens_per_req + req_ids = torch.arange(batch_size, device=device) + token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req) + token_positions = torch.arange( + num_tokens, device=device) - start_indices[token_req_ids] + + # Find the first mismatch position of each request. 
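# The expansion documented above ([a, b, c] with cu_num_tokens = [2, 5, 6] ->
# [a, a, b, b, b, c]) is equivalent to repeat_interleave over per-request token
# counts; a minimal check of that equivalence:
import torch

x = torch.tensor([10, 20, 30])                    # one value per request
cu_num_tokens = torch.tensor([2, 5, 6])           # inclusive cumulative counts
counts = torch.diff(cu_num_tokens, prepend=torch.tensor([0]))
expanded = torch.repeat_interleave(x, counts)
assert expanded.tolist() == [10, 10, 20, 20, 20, 30]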
+ mismatch_global = (draft_token_ids != target_argmax) + if max_spec_len == 0: + first_mismatch_pos_per_req = torch.zeros(batch_size, + dtype=torch.long, + device=device) + else: + # [bs, max_spec_len] + pos_matrix = torch.full((batch_size, max_spec_len), + -1, + dtype=torch.long, + device=device) + pos_matrix[token_req_ids, token_positions] = token_positions + mismatch_matrix = torch.full((batch_size, max_spec_len), + False, + dtype=torch.bool, + device=device) + mismatch_matrix[token_req_ids, token_positions] = mismatch_global + mismatch_positions = torch.where(mismatch_matrix, pos_matrix, + max_spec_len * 2) + first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1) + no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2) + first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[ + no_mismatch_mask] + + # Copy matched target tokens into output. + copy_len = torch.minimum(first_mismatch_pos_per_req + 1, + draft_tokens_per_req) + copy_indices = torch.arange(max_spec_len + 1, + device=device).expand(batch_size, -1) + copy_mask = copy_indices < copy_len.unsqueeze(1) + greedy_mask = is_greedy.unsqueeze(1) + final_copy_mask = copy_mask & greedy_mask + global_idx = start_indices.unsqueeze(1) + copy_indices + output_token_ids[final_copy_mask] = target_argmax[ + global_idx[final_copy_mask]].to(output_token_ids.dtype) + # Fill bonus token. + needs_bonus = is_greedy & (first_mismatch_pos_per_req + >= draft_tokens_per_req) + if torch.any(needs_bonus): + bonus_rows = torch.where(needs_bonus)[0] + bonus_cols = draft_tokens_per_req[bonus_rows] + bonus_token_ids = bonus_token_ids.squeeze(1) + output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows] + + +def rejection_random_sample_pytorch( + output_token_ids, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + draft_probs, # [num_tokens, vocab_size] or None + target_probs, # [num_tokens, vocab_size] + bonus_token_ids, # [batch_size] + recovered_token_ids, # [num_tokens] + uniform_probs, # [num_tokens] + is_greedy, # [batch_size] + max_spec_len, + vocab_size, + IS_NGRAM=False, +): + batch_size = output_token_ids.shape[0] + + for req_idx in range(batch_size): + if is_greedy[req_idx]: + continue + + if req_idx == 0: + start_idx = 0 + else: + start_idx = cu_num_draft_tokens[req_idx - 1].item() + end_idx = cu_num_draft_tokens[req_idx].item() + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = draft_token_ids[start_idx + pos].item() + + if IS_NGRAM: + draft_prob = 1.0 + else: + draft_prob = draft_probs[start_idx + pos, + draft_token_id].item() + + target_prob = target_probs[start_idx + pos, + draft_token_id].item() + uniform_prob = uniform_probs[start_idx + pos].item() + + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + token_id = draft_token_id + else: + rejected = True + token_id = recovered_token_ids[start_idx + pos].item() + + output_token_ids[req_idx, pos] = token_id + + if not rejected: + bonus_token_id = bonus_token_ids[req_idx].item() + output_token_ids[req_idx, num_draft_tokens] = bonus_token_id + + +def expand_pytorch( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + MAX_NUM_TOKENS, +): + batch_size = len(input_ptr) + + for req_idx in range(batch_size): + start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1] + end_idx = cu_num_tokens_ptr[req_idx] + 
num_tokens = end_idx - start_idx + + src_val = input_ptr[req_idx] + src_val = replace_to if src_val == replace_from else src_val + + offset = torch.arange(MAX_NUM_TOKENS, device=num_tokens.device) + mask = offset < num_tokens + + output_slice = start_idx + offset[mask] + output_ptr[output_slice] = src_val + + +def sample_recovered_tokens_pytorch( + output_token_ids, # [num_tokens] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + draft_probs, # [num_tokens, vocab_size] or None + target_probs, # [num_tokens, vocab_size] + q, # [batch_size, vocab_size] + vocab_size, + IS_NGRAM=False, +): + batch_size = len(cu_num_draft_tokens) + + for req_idx in range(batch_size): + start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1] + end_idx = cu_num_draft_tokens[req_idx] + num_draft_tokens = end_idx - start_idx + + for pos in range(num_draft_tokens): + token_idx = start_idx + pos + + if IS_NGRAM: + draft_token_id = draft_token_ids[token_idx] + orig_prob = target_probs[token_idx, draft_token_id].item() + target_probs[token_idx, draft_token_id] = 0 + prob = target_probs[token_idx].clone() + else: + draft_p = draft_probs[token_idx].clone() + target_p = target_probs[token_idx].clone() + prob = torch.maximum(target_p - draft_p, + torch.tensor(0.0, device=target_p.device)) + + q_values = torch.full((vocab_size, ), + float('-inf'), + device=q.device) + q_values[:vocab_size] = q[req_idx, :vocab_size] + + recovered_id = torch.argmax(prob / q_values).item() + output_token_ids[token_idx] = recovered_id + + if IS_NGRAM: + target_probs[token_idx, draft_token_id] = orig_prob + + +rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_npu/sample/sampler.py b/vllm_npu/sample/sampler.py new file mode 100644 index 0000000..0d8d38a --- /dev/null +++ b/vllm_npu/sample/sampler.py @@ -0,0 +1,74 @@ +import torch +import torch_npu +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample +from vllm.v1.sample.sampler import Sampler + +from vllm_npu.utils import is_310p + +DEFAULT_LOGPROBS_MODE = "raw_logprobs" + + +class AscendSampler(Sampler): + + def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE): + # TODO: support logprobs_mode in vllm-ascend + super().__init__(logprobs_mode=logprobs_mode) + self.topk_topp_sampler = AscendTopKTopPSampler() + + +class AscendTopKTopPSampler(TopKTopPSampler): + + def _apply_top_k_top_p( + self, + logits: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, + ) -> torch.Tensor: + if p is None and k is None: + return logits + # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P + if not is_310p(): + # npu_top_k_top_p requires parameter k ranged from 1 to 1024 + if k is None or 1 <= int(k.max()) <= 1024: + # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p) + return torch_npu.npu_top_k_top_p(logits, p, k) + + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to( + torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. 
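# A quick cross-check of the cutoff computed above: on a toy batch the value
# gathered from the ascending sort equals the smallest of the top-k
# probabilities returned by torch.topk. The sorted tensor is reused by the
# top-p branch further down, which is why the fallback works on probs_sort.
import torch

probs_demo = torch.tensor([[0.1, 0.4, 0.3, 0.2]])
k_demo = torch.tensor([2])

cutoff_topk = torch.topk(probs_demo, int(k_demo[0]), dim=-1).values[..., -1:]

probs_sorted, _ = probs_demo.sort(dim=-1, descending=False)
count = (probs_sorted.size(1) - k_demo.to(torch.long)).unsqueeze(1)
cutoff_sorted = probs_sorted.gather(-1, count)
assert torch.allclose(cutoff_topk, cutoff_sorted)   # both 0.3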
+ no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits + + def forward_native(self, logits, generators, k, p): + """Override pytorch native implementation to torch_npu""" + logits = self._apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == "processed_logits": + logits_to_return = logits + elif self.logprobs_mode == "processed_logprobs": + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) + + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators), logits_to_return diff --git a/vllm_npu/spec_decode/__init__.py b/vllm_npu/spec_decode/__init__.py new file mode 100644 index 0000000..fc0ee40 --- /dev/null +++ b/vllm_npu/spec_decode/__init__.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py +# +from vllm_npu.spec_decode.eagle_proposer import EagleProposer +from vllm_npu.spec_decode.mtp_proposer import MtpProposer +from vllm_npu.spec_decode.ngram_proposer import NgramProposer + + +def get_spec_decode_method(method, vllm_config, device, runner): + if method == "ngram": + return NgramProposer(vllm_config, device, runner) + elif method in ["eagle", "eagle3"]: + return EagleProposer(vllm_config, device, runner) + elif method == 'deepseek_mtp': + return MtpProposer(vllm_config, device, runner) + else: + raise ValueError("Unknown speculative decoding method: " + f"{method}") diff --git a/vllm_npu/spec_decode/eagle_proposer.py b/vllm_npu/spec_decode/eagle_proposer.py new file mode 100644 index 0000000..8061b1b --- /dev/null +++ b/vllm_npu/spec_decode/eagle_proposer.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +from vllm.attention.layer import Attention +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import logger +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm_npu.ascend_forward_context import set_ascend_forward_context +from vllm_npu.attention.attention_mask import AttentionMaskBuilder +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.attention.utils import AscendCommonAttentionMetadata +from vllm_npu.spec_decode.interface import Proposer, SpecDcodeType + +PADDING_SLOT_ID = -1 + + +class EagleProposer(Proposer): + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + runner=None): + self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3 + self.vllm_config = vllm_config + self.device = device + self.runner = runner + + self.block_size = vllm_config.cache_config.block_size + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size( + ) + + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not self.vllm_config.model_config.enforce_eager) + self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # persistent buffers for cuda graph + self.input_ids = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int32, + device=device) + self.positions = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (self.vllm_config.scheduler_config.max_num_batched_tokens, + self.hidden_size), + dtype=self.vllm_config.model_config.dtype, + device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. 
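# Why the arange needs one extra entry: query_start_loc is a prefix sum over
# per-request query lengths with a leading zero, so a batch of B requests has
# B + 1 boundaries. For query lengths [2, 5, 3] it is [0, 2, 7, 10], and the
# tokens of request i occupy [query_start_loc[i], query_start_loc[i + 1]) in
# the flat token buffer.
import torch

query_lens = torch.tensor([2, 5, 3])
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)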
+ self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=device, + dtype=torch.int32) + attn_mask_len = self.vllm_config.model_config.max_model_len + self.attn_mask_builder = AttentionMaskBuilder( + attn_mask_len, self.vllm_config.model_config.dtype, device=device) + + def load_model(self, model: nn.Module) -> None: + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + self.model = get_model(vllm_config=self.vllm_config, + model_config=self.vllm_config. + speculative_config.draft_model_config) + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + self.attn_layer_name = next(iter(draft_attn_layer_names)) + + # share embed_tokens with the target model if needed + if get_pp_group().world_size == 1: + logger.info( + "The EAGLE head shares the same vocab embedding" \ + " with the target model." + ) + self.model.model.embed_tokens = model.model.embed_tokens + else: + logger.info( + "Since PP > 1, the EAGLE head loaded its own vocab embedding" \ + " weights instead of sharing them with the target model." + ) + + # share lm_head with the target model if needed + # some model definition do not define lm_head explicitly + # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM + if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"): + logger.info("Loading EAGLE LM head weights from the target model.") + if supports_multimodal(model): + self.model.lm_head = model.get_language_model().lm_head + else: + self.model.lm_head = model.lm_head + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): + moe_comm_type = self.runner._select_moe_comm_method( + num_tokens, with_prefill) + with set_ascend_forward_context(None, + self.vllm_config, + moe_comm_type=moe_comm_type, + num_tokens=num_tokens): + self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + ) + dummy_compute_logits(self.hidden_states) + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + + attn_metadata = self._get_eagle_atten_dict(scheduler_output) + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ req_id = self.runner.input_batch.req_ids[i] + req_state = self.runner.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + eagle_attn_metadata = attn_metadata[self.attn_layer_name] + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc + else: + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) + cu_num_tokens, token_indices = self._prepare_inputs( + eagle_attn_metadata.query_start_loc, num_rejected_tokens, + num_tokens) + target_token_ids = self.runner.input_ids[token_indices] + target_positions = positions[token_indices] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] + + draft_token_ids = self._propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=eagle_attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + ) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + + def _get_eagle_atten_dict( + self, + scheduler_output: "SchedulerOutput", + ): + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.runner.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.runner.input_batch.block_table.commit_block_table(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = self.runner.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) + self.runner.query_lens = torch.from_numpy(num_scheduled_tokens) + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.runner.arange_np[:num_reqs], + num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + + # Get positions. 
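# The flattened-batch bookkeeping above on a concrete example: for per-request
# token counts [2, 5, 3], req_indices = [0,0,1,1,1,1,1,2,2,2] and
# arange = [0,1,0,1,2,3,4,0,1,2], built without a Python loop exactly as
# _get_cumsum_and_arange does further down.
import numpy as np

num_tokens_demo = np.array([2, 5, 3], dtype=np.int32)
cu = np.cumsum(num_tokens_demo)                                   # [2, 7, 10]
req_indices_demo = np.repeat(np.arange(3, dtype=np.int32), num_tokens_demo)
offsets = np.repeat(cu - num_tokens_demo, num_tokens_demo)
arange_demo = np.arange(cu[-1], dtype=np.int32) - offsets
# positions are then num_computed_tokens[req_indices] + arange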
+ positions_np = self.runner.positions_np[:total_num_scheduled_tokens] + np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.runner.uses_mrope: + self.runner._calc_mrope_positions(scheduler_output) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = ( + positions_np + + req_indices * self.runner.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select( + self.runner.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.runner.input_ids_cpu[:total_num_scheduled_tokens]) + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attetnion layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.runner.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table = self.runner.input_batch.block_table[ + kv_cache_group_id] + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + self.runner.query_start_loc_np[0] = 0 + self.runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + + self.runner.seq_lens_np[:num_reqs] = ( + self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + + # Copy the tensors to the NPU. + self.runner.input_ids[:total_num_scheduled_tokens].copy_( + self.runner.input_ids_cpu[:total_num_scheduled_tokens], + non_blocking=True) + if self.runner.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.runner.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.runner. + mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + # Common case (1D positions) + self.runner.positions[:total_num_scheduled_tokens].copy_( + self.runner.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) + + self.runner.query_start_loc[:num_reqs + 1].copy_( + self.runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + self.runner.seq_lens[:num_reqs].copy_( + self.runner.seq_lens_cpu[:num_reqs], non_blocking=True) + + # Fill unused with -1. 
Needed for reshape_and_cache + self.runner.seq_lens[num_reqs:].fill_(0) + self.runner.query_start_loc[num_reqs + 1:].fill_(-1) + + attn_metadata = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.runner.kv_cache_config.kv_cache_groups): + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[:num_reqs + + 1], + seq_lens_cpu=self.runner.seq_lens_cpu, + num_reqs=num_reqs, + max_query_len=max_num_scheduled_tokens, + num_actual_tokens=total_num_scheduled_tokens, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping=self.runner.slot_mapping, + positions=self.runner.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=None, + seq_lens=None) + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_i = builder.build(0, common_attn_metadata, + self.runner.get_model()) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + return attn_metadata + + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.runner.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def _propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + device = cu_num_tokens.device + cu_num_tokens = cu_num_tokens.cpu() + block_table = block_table.cpu() + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + target_positions = target_positions.cpu() + if self.name == SpecDcodeType.EAGLE3: + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. 
+ # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + seq_lens = (target_positions[last_token_indices] + 1).int() + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + attn_mask = self.runner.attn_mask + + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=cu_num_tokens.to(device), + query_start_loc_cpu=cu_num_tokens, + seq_lens_cpu=seq_lens.cpu(), + max_query_len=max_query_len, + num_reqs=batch_size, + num_actual_tokens=num_tokens, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping=target_slot_mapping, + positions=target_positions, + attn_mask=attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=None, + seq_lens=None) + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata = builder.build(0, common_attn_metadata, + self.runner.get_model()) + if self.use_cuda_graph and \ + num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + num_input_tokens = num_tokens + + with_prefill = attn_metadata.attn_state not in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + moe_comm_type = self.runner._select_moe_comm_method( + num_input_tokens, with_prefill) + + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions.to(device) + self.hidden_states[:num_tokens] = target_hidden_states + attn_metadata.block_tables = block_table.to(device) + with set_ascend_forward_context(attn_metadata, + self.vllm_config, + moe_comm_type=moe_comm_type, + num_tokens=num_input_tokens): + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + ) + sample_hidden_states = last_hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states) + draft_token_ids = logits.argmax(dim=-1) + + # Early exit if there is only one draft token to be generated. + if self.vllm_config.speculative_config.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + # Generate the remaining draft tokens. 
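A self-contained check of the shift-by-one input construction used at the top of _propose; the token values and request lengths are invented to mirror the [a1, b1, b2, c1, c2, c3] example in the comments.

import torch

# Three requests of lengths 1, 2 and 3: [a1 | b1, b2 | c1, c2, c3]
target_token_ids = torch.tensor([101, 201, 202, 301, 302, 303])
cu_num_tokens = torch.tensor([0, 1, 3, 6])
next_token_ids = torch.tensor([102, 203, 304])      # a2, b3, c4 from sampling

num_tokens = target_token_ids.shape[0]
input_ids = torch.empty(num_tokens, dtype=torch.long)
# Shift left by one so each slot sees the token that followed it ...
input_ids[:num_tokens - 1] = target_token_ids[1:]
# ... then overwrite the last slot of every request with the sampled token.
last_token_indices = cu_num_tokens[1:] - 1          # [0, 2, 5]
input_ids[last_token_indices] = next_token_ids
print(input_ids.tolist())
# [102, 202, 203, 302, 303, 304]  ==  [a2, b2, b3, c2, c3, c4]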
+ draft_token_ids_tensor = torch.zeros( + (self.vllm_config.speculative_config.num_speculative_tokens, + *draft_token_ids.shape), + dtype=draft_token_ids.dtype) + draft_token_ids_tensor[0] = draft_token_ids + + positions_cpu = target_positions[last_token_indices].cpu().to( + torch.int64) + hidden_states = hidden_states[last_token_indices] + if self.use_cuda_graph and \ + batch_size <= self.cudagraph_batch_sizes[-1]: + input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) + else: + input_batch_size = batch_size + + moe_comm_type = self.runner._select_moe_comm_method( + input_batch_size, False) + + attn_metadata.num_actual_tokens = batch_size + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size + 1] + attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[ + 1:].tolist() + attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size + attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens + query_lens.fill_(1) + attn_metadata.query_lens = query_lens + + attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)] + attn_metadata.seq_lens_list = seq_lens.tolist() + attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill + for now_speculative in range( + self.vllm_config.speculative_config.num_speculative_tokens - + 1): + # Update the inputs. + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. + input_ids = draft_token_ids_tensor[now_speculative].to(device) + positions_cpu += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions_cpu >= self.vllm_config.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions_cpu = torch.where(exceeds_max_model_len, 0, + positions_cpu) + clamped_positions = clamped_positions_cpu.to(device) + + # TODO: Increment the sequence lengths. + + attn_metadata.seq_lens += 1 + attn_metadata.seq_lens_list = [ + _ + 1 for _ in attn_metadata.seq_lens_list + ] + # TODO: Consider max model length. + # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + # self.max_model_len) + # For the requests that exceed the max model length, we set the + # TODO: sequence length to 1 to minimize their overheads in attention. + + # Compute the slot mapping. + block_numbers = (clamped_positions_cpu // self.block_size) + block_ids = block_table.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + slot_mapping_cpu = ( + block_ids * self.vllm_config.cache_config.block_size + + clamped_positions_cpu % self.block_size) + + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. 
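The slot-mapping arithmetic in this drafting loop, ending with the padding-slot mask described in the comment above, can be verified in isolation; the block table, block size and positions below are toy values.

import torch

PADDING_SLOT_ID = -1
block_size = 4
max_model_len = 16

# One row of physical block ids per request.
block_table = torch.tensor([[7, 2, 9, 0],
                            [3, 5, 1, 8]])
positions = torch.tensor([5, 17])        # request 1 has run past max_model_len

exceeds_max_model_len = positions >= max_model_len
clamped_positions = torch.where(exceeds_max_model_len,
                                torch.zeros_like(positions), positions)

block_numbers = clamped_positions // block_size
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Out-of-range requests write to the padding slot, leaving the KV cache intact.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
print(slot_mapping.tolist())             # [9, -1]   (block 2, offset 1 -> slot 9)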
+ slot_mapping_cpu.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) + # NOTE: ASCEND slot_mapping must on cpu + attn_metadata.slot_mapping = slot_mapping_cpu.to( + torch.int32).to(device) + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask( + attn_metadata.seq_lens, positions_cpu, + self.vllm_config.model_config.dtype, self.device) + + attn_metadata.attn_mask = attn_mask + attn_metadata.block_tables = block_table.to(device) + # Run the model. + with set_ascend_forward_context(attn_metadata, + self.vllm_config, + moe_comm_type=moe_comm_type, + num_tokens=input_batch_size): + + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:input_batch_size], + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + ) + hidden_states = hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size]) + + # TODO(wenlong): get more than one token for tree attention + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu() + + # [batch_size, num_speculative_tokens] + draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) + return draft_token_ids + + def _prepare_inputs( + self, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + num_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + # [a - n1, b - n2, c - n3] -> + # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + cu_num_tokens = torch.zeros_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_target_query_lens.device, + ) + BLOCK_SIZE = 1024 + self._prepare_eagle_input_sequential( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor, + cu_query_lens: torch.Tensor, + cu_num_tokens: torch.Tensor, + block_size: int): + num_programs = len(cu_num_tokens) - 1 + for pid in range(num_programs): + start_pos = cu_num_tokens[pid].item() + end_pos = cu_num_tokens[pid + 1].item() + num_tokens = end_pos - start_pos + index_start = cu_query_lens[pid].item() + num_blocks = int( + torch.ceil(torch.tensor(num_tokens / block_size)).item()) + + for i in range(num_blocks): + offset_tensor = torch.arange(0, + block_size, + dtype=torch.int32, + device=out_tensor.device) + global_start_offset = i * block_size + target_indices = torch.tensor( + start_pos + global_start_offset, + dtype=torch.int32, + device=out_tensor.device) + offset_tensor + values_to_store = torch.tensor( + index_start + global_start_offset, + dtype=torch.int32, + 
device=out_tensor.device) + offset_tensor + mask = (target_indices >= start_pos) & \ + (target_indices < end_pos) & \ + (offset_tensor < num_tokens) + out_tensor[target_indices[mask]] = values_to_store[mask] diff --git a/vllm_npu/spec_decode/interface.py b/vllm_npu/spec_decode/interface.py new file mode 100644 index 0000000..ad4e751 --- /dev/null +++ b/vllm_npu/spec_decode/interface.py @@ -0,0 +1,54 @@ +import enum +from typing import Optional + +import torch +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + + +class SpecDcodeType(enum.Enum): + NGRAM = 0 + EAGLE = 1 + EAGLE3 = 2 + MTP = 4 + + +class Proposer: + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device = None, + runner=None): + pass + + def load_model(self, model): + """Called by load_model in model_runner""" + raise NotImplementedError + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp: Optional[torch.Tensor] = None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): + """Called by dummy_run in modle_runner""" + raise NotImplementedError + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + """Called by execute_model in model_runner""" + raise NotImplementedError diff --git a/vllm_npu/spec_decode/mtp_proposer.py b/vllm_npu/spec_decode/mtp_proposer.py new file mode 100644 index 0000000..2bd08b8 --- /dev/null +++ b/vllm_npu/spec_decode/mtp_proposer.py @@ -0,0 +1,675 @@ +import types + +import torch +import torch.nn as nn +import torchair +from torchair import patch_for_hcom +from vllm.attention.layer import Attention +from vllm.config import (CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, set_default_torch_dtype) +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.ascend_forward_context import set_ascend_forward_context +from vllm_npu.attention.utils import AscendCommonAttentionMetadata +from vllm_npu.patch.worker.patch_deepseek_mtp import \ + AscendDeepSeekMTP as DeepSeekMTP +from vllm_npu.spec_decode.interface import Proposer, SpecDcodeType +from vllm_npu.torchair.models.torchair_deepseek_mtp import \ + TorchairDeepSeekMTP +from vllm_npu.torchair.utils import (TORCHAIR_CACHE_DIR, + TorchairCommonAttentionMetadata) +from vllm_npu.utils import ProfileExecuteDuration, lmhead_tp_enable + +PADDING_SLOT_ID = -1 + + +class MtpProposer(Proposer): + + def __init__( + self, + vllm_config: VllmConfig, + device, + runner, + ): + self.name = SpecDcodeType.MTP + self.vllm_config = 
vllm_config + self.device = device + self.runner = runner + self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + + # persistent buffers for graph + self.input_ids = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int64, + device=self.device) + self.hidden_states = torch.zeros( + (self.runner.max_num_tokens, + vllm_config.model_config.get_hidden_size()), + dtype=self.runner.dtype, + device=self.device) + self.torchair_compiled_model = None # type: ignore + self.torchair_compiled_models = {} # type: ignore + self.torchair_graph_enabled = get_ascend_config( + ).torchair_graph_config.enabled + self.enable_shared_expert_dp = get_ascend_config( + ).enable_shared_expert_dp + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=self.runner.device, + dtype=torch.int32) + self.use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") + + def load_model(self, model) -> None: + loader = get_model_loader(self.vllm_config.load_config) + + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + target_device = self.vllm_config.device_config.device + + with set_default_torch_dtype( + draft_model_config.dtype), set_current_vllm_config( + self.vllm_config): + if self.torchair_graph_enabled or ( + self.enable_shared_expert_dp + and self.vllm_config.model_config.use_mla): + self.model = TorchairDeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + else: + self.model = DeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = list(draft_attn_layer_names) + + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) + process_weights_after_loading(self.model, draft_model_config, + target_device) + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None) -> None: + if not self.torchair_graph_enabled: + # TODO: adapt enable_dbo later + (num_tokens, num_tokens_across_dp, with_prefill, + _) = self.runner._sync_metadata_across_dp(num_tokens, + with_prefill, False) + + moe_comm_type = self.runner._select_moe_comm_method( + num_tokens, with_prefill) + + is_running_torchair = self.torchair_graph_enabled and \ + not with_prefill + + if is_running_torchair: + skip_attn = False + if skip_attn: + attn_metadata = None + else: + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + decode_token_per_req=self.runner.decode_token_per_req, + ) + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + common_attn_metadata) + + input_ids = self.input_ids[:num_tokens] + positions = 
self.positions[:num_tokens] + previous_hidden_states = self.hidden_states[:num_tokens] + for _ in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor): + if is_running_torchair: + assert attn_metadata is not None + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(previous_hidden_states) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static( + attn_metadata.decode.input_positions) + if hasattr(attn_metadata.decode, "sin"): + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static(get_forward_context().mc2_mask) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + torch._dynamo.mark_static(attn_metadata.decode.attn_mask) + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_tokens) + torchair_compiled_model( + input_ids=input_ids, + positions=positions, + hidden_states=previous_hidden_states, + inputs_embeds=None, + intermediate_tensors=None, + attn_metadata=attn_metadata, + kv_caches=self.runner.kv_caches[-1:], + spec_step_idx=0) + else: + self.model(input_ids=input_ids, + positions=positions, + hidden_states=previous_hidden_states) + dummy_compute_logits(previous_hidden_states) + if with_prefill: + break + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.runner.input_batch.req_ids[i] + req_state = self.runner.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + accepted_token_indices = None + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: + # TODO(woosuk): Refactor this. 
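A worked example of the acceptance bookkeeping this branch performs: with n draft tokens per request, the target model verifies n + 1 positions, so the rejection count is n + 1 minus the number of tokens that survived verification. The lists below are illustrative.

num_draft_tokens = [3, 2]                      # drafts proposed per request
valid_sampled_token_ids = [[7, 8], [5, 6, 9]]  # tokens kept after verification

next_token_ids = [ids[-1] for ids in valid_sampled_token_ids]
num_rejected_tokens = [
    n + 1 - len(ids) if n > 0 else 0
    for n, ids in zip(num_draft_tokens, valid_sampled_token_ids)
]
print(next_token_ids, num_rejected_tokens)     # [8, 9] [2, 0]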
+ num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + cu_num_tokens, accepted_token_indices, target_token_ids, \ + target_positions, target_hidden_states, target_slot_mapping = self._prepare_inputs( + attn_metadata.query_start_loc, + num_rejected_tokens, + self.runner.input_ids[:num_scheduled_tokens], + positions[:num_scheduled_tokens], + hidden_states[:num_scheduled_tokens], + attn_metadata.slot_mapping[:num_scheduled_tokens], + is_torchair_graph=self.runner._build_drafter_prepare_inputs_torchair_param(), + ) + + draft_token_ids = self._propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + token_indices=accepted_token_indices) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + + def _prepare_inputs( + self, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + token_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + slot_mapping: torch.Tensor, + is_torchair_graph: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + if is_torchair_graph: + cu_num_tokens = cu_target_query_lens + relative_index = query_len_per_req - num_rejected_tokens - 1 + token_indices = cu_num_tokens[:-1] + relative_index + # the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model + target_token_ids = token_ids + target_positions = positions + target_hidden_states = hidden_states + target_slot_mapping = slot_mapping + else: + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. 
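The re-indexing documented in the comment block above (drop each request's rejected tail, then re-pack the survivors), shown with toy sizes a=3, b=2, c=4. The torch.cat-based gather is only for illustration; the real code builds token_indices with _prepare_input_kernel.

import torch

cu_target_query_lens = torch.tensor([0, 3, 5, 9])   # a=3, b=2, c=4
num_rejected_tokens = torch.tensor([1, 0, 2])       # n1, n2, n3

query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
num_tokens_per_req = query_len_per_req - num_rejected_tokens      # [2, 2, 2]

cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])    # [0, 2, 4, 6]

# Keep, for every request, the first (len - rejected) tokens of its slice.
token_indices = torch.cat([
    torch.arange(start, start + keep)
    for start, keep in zip(cu_target_query_lens[:-1].tolist(),
                           num_tokens_per_req.tolist())
])
print(cu_num_tokens.tolist(), token_indices.tolist())
# [0, 2, 4, 6] [0, 1, 3, 4, 5, 6]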
+ num_tokens = cu_num_tokens[-1].item() + token_indices = torch.zeros( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) + + BLOCK_SIZE = 1024 + self._prepare_input_kernel( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) + target_token_ids = token_ids[token_indices] + target_positions = positions[token_indices] + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = slot_mapping[token_indices] + return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping + + def _propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + token_indices=None) -> torch.Tensor: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + if token_indices is not None and self.torchair_graph_enabled: + last_token_indices = token_indices + + self.input_ids[last_token_indices] = next_token_ids + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + # FIXME: reorder_batch() needs to be called before build() + # because fields of attn_metadata_builder needs to be updated. + # However, currently reorder_batch() takes input_batch and + # scheduler_output as arguments, we should probably refactor + # the method to use new data structures which are independent + # from input_batch and scheduler_output. + # self.runner.attn_metadata_builder.reorder_batch( + # input_batch=self.runner.input_batch, + # scheduler_output=self.runner.scheduler_output, + # ) + is_running_torchair = self.torchair_graph_enabled and \ + not self.runner.with_prefill + + if is_running_torchair: + # Torchair graph mode, padding is same as the main model + num_input_tokens = self.runner.graph_pad_size + elif (self.runner.use_aclgraph + and num_tokens <= self.runner.aclgraph_batch_sizes[-1]): + # Acl graph mode, add padding to the batch size + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + # Eager mode, no padding needed + num_input_tokens = num_tokens + + seq_lens = target_positions[last_token_indices] + 1 + seq_lens = seq_lens.int() + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=cu_num_tokens[:batch_size + 1], + query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), + seq_lens_cpu=seq_lens.cpu(), + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. 
+ get_device_tensor(), + slot_mapping=target_slot_mapping, + positions=target_positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + graph_pad_size=self.runner.graph_pad_size, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=None, + seq_lens=None) + + if not self.torchair_graph_enabled: + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_mtp = builder.build(0, common_attn_metadata, + self.runner.get_model()) + + attn_metadata = {} + for layer_name in self.attn_layer_name: + attn_metadata[layer_name] = attn_metadata_mtp + + else: + attn_metadata = self.runner.attn_metadata_builder.build( + 0, common_attn_metadata, self.runner.get_model()) + + self.positions[:num_tokens] = target_positions + self.hidden_states[:num_tokens] = target_hidden_states + + if not self.torchair_graph_enabled: + # torch mode need to update num_tokens_across_dp + # TODO: adapt enable_dbo later + (num_input_tokens, num_tokens_across_dp, with_prefill, + _) = self.runner._sync_metadata_across_dp( + num_input_tokens, self.runner.with_prefill, False) + else: + # torchair mode can reuse self.runner.num_tokens_across_dp + num_tokens_across_dp = self.runner.num_tokens_across_dp + with_prefill = self.runner.with_prefill + + moe_comm_type = self.runner._select_moe_comm_method( + num_input_tokens, with_prefill) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=False) + aclgraph_runtime_mode, batch_descriptor = \ + self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) + + for step in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=num_tokens): + with ProfileExecuteDuration().capture_async('mtp_forward'): + model_kwargs = {} + model_kwargs["attn_metadata"] = attn_metadata + if self.torchair_graph_enabled: + model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] + if is_running_torchair: + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_input_tokens) + hidden_states = torchair_compiled_model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self. 
+ hidden_states[:num_input_tokens], + inputs_embeds=None, + intermediate_tensors=None, + spec_step_idx=0, + **model_kwargs) + else: + hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens] + ) + + num_indices = last_token_indices.shape[0] + if lmhead_tp_enable(): + if not self.runner.with_prefill: + max_num_reqs_across_dp = num_input_tokens + else: + max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs + last_token_indices = nn.functional.pad( + last_token_indices, + (0, max_num_reqs_across_dp - num_indices)) + + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states) + if lmhead_tp_enable() and num_indices < logits.shape[0]: + logits = logits[:num_indices] + last_token_indices = last_token_indices[:num_indices] + draft_token_ids = logits.argmax(dim=-1) + + if self.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + if step == 0: + draft_token_ids_list = [draft_token_ids] + else: + draft_token_ids_list.append(draft_token_ids) + + # prepare next mtp inputs + # mtp>1: prefill skip or decode skip last loop + if with_prefill: + for _ in range(self.num_speculative_tokens - 1): + draft_token_ids_list.append(draft_token_ids) + if step == self.num_speculative_tokens - 1 or with_prefill: + break + + if not self.torchair_graph_enabled: + attn_metadata_i = attn_metadata[self.attn_layer_name[0]] + else: + attn_metadata_i = attn_metadata + + if step == 0: + positions = target_positions[last_token_indices] + hidden_states = hidden_states[last_token_indices] + slot_mapping = attn_metadata_i.slot_mapping[last_token_indices] + attn_metadata_i.slot_mapping.fill_(-1) + attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] + last_token_indices = self.arange[:batch_size] + if attn_metadata_i.num_decode_tokens != 0: + attn_metadata_i.num_decode_tokens = batch_size + if is_running_torchair: + attn_metadata_i.num_actual_tokens = batch_size + attn_metadata_i.query_lens = [1] * batch_size + + input_ids = draft_token_ids_list[-1].int() + positions += 1 + + if not self.torchair_graph_enabled: + attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ + 1:batch_size + 1].tolist() + attn_metadata_i.decode.cos = builder.cos_cache[ + positions].unsqueeze(1).unsqueeze(2) + attn_metadata_i.decode.sin = builder.sin_cache[ + positions].unsqueeze(1).unsqueeze(2) + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions >= self.runner.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + # Increment the sequence lengths. + attn_metadata_i.seq_lens[:batch_size] += 1 + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. 
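A small illustration of the max-model-len guard applied at this point in the loop: positions past the limit are clamped to 0 so RoPE indexing stays in range, and the corresponding sequence lengths are pinned to 1 so the attention work for these to-be-discarded drafts stays minimal. The tensors are made up.

import torch

max_model_len = 16
positions = torch.tensor([9, 17])      # request 1 has run past the limit
seq_lens = torch.tensor([11, 17])      # already incremented for this step

exceeds_max_model_len = positions >= max_model_len
clamped_positions = torch.where(exceeds_max_model_len,
                                torch.zeros_like(positions), positions)
seq_lens.masked_fill_(exceeds_max_model_len, 1)
print(clamped_positions.tolist(), seq_lens.tolist())   # [9, 0] [11, 1]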
+ exceeds_max_model_len_cpu = exceeds_max_model_len.to( + attn_metadata_i.seq_lens.device, non_blocking=False) + attn_metadata_i.seq_lens[:batch_size].masked_fill_( + exceeds_max_model_len_cpu, 1) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping += 1 + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:hidden_states.shape[0]] = hidden_states + attn_metadata_i.slot_mapping[:batch_size] = slot_mapping + + if attn_metadata_i.prefill is not None: + attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist( + ) + attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens + attn_metadata_i.prefill.input_positions = self.positions[: + num_input_tokens] + attn_metadata_i.prefill.max_seq_lens += 1 + attn_metadata_i.prefill.max_seq_lens = min( + attn_metadata_i.prefill.max_seq_lens, + self.runner.model_config.max_model_len) + if attn_metadata_i.decode is not None: + attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( + ) + attn_metadata_i.decode.input_positions = self.positions[: + num_input_tokens] + attn_metadata_i.decode.max_seq_lens += 1 + attn_metadata_i.decode.max_seq_lens = min( + attn_metadata_i.decode.max_seq_lens, + self.runner.model_config.max_model_len) + + # mtp>1: [batch_size, k] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids + + def _get_torchair_lazy_compiled_model(self, batch_size: int): + if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ + -1]: + raise ValueError( + f"Bad graph batch size:{batch_size}! 
max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}" + ) + + compiled_model = self.torchair_compiled_models.get( + batch_size + ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model + + if compiled_model: + return compiled_model + + patch_for_hcom() + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + config.experimental_config.enable_view_optimize = \ + get_ascend_config().torchair_graph_config.enable_view_optimize + torch.npu.set_compile_mode(jit_compile=False) + if not self.runner.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=not self.use_sparse, + fullgraph=True, + backend=npu_backend) + return self.torchair_compiled_model + else: + # Generate a new forward proxy code object to prevent the invalidation of + # compilation cache caused by dynamo retracing + forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" + forward_fn = self.model.forward + code = forward_fn.__code__ + # Mark code object with a new proxy name + modified_code = code.replace(co_name=forward_proxy_name, ) + + modified_func = types.FunctionType(modified_code, + forward_fn.__globals__, + name=forward_proxy_name, + argdefs=forward_fn.__defaults__) + + self.model.__dict__[forward_proxy_name] = modified_func.__get__( + self.model, nn.Module) + self.torchair_compiled_models[ + batch_size] = torchair.inference.cache_compile( + self.model.__dict__[forward_proxy_name], + dynamic=not self.use_sparse, + fullgraph=True, + cache_dir=TORCHAIR_CACHE_DIR, + config=config, + ge_cache=False) + return self.torchair_compiled_models[batch_size] + + # TODO Using torch instead of triton may result in poor performance + def _prepare_input_kernel(self, out_ptr: torch.Tensor, + cu_query_lens: torch.Tensor, + cu_num_tokens: torch.Tensor, block_size: int): + device = cu_query_lens.device + dtype = out_ptr.dtype + + offsets = torch.arange(block_size, device=device, dtype=dtype) + start_pos = cu_num_tokens[:-1] + end_pos = cu_num_tokens[1:] + num_tokens = end_pos - start_pos + + global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1)) + values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)) + + mask = (offsets.view(1, -1) < num_tokens.view(-1, 1)) + + global_indices_flat = global_indices[mask] + values_flat = values[mask] + out_ptr[global_indices_flat] = values_flat diff --git a/vllm_npu/spec_decode/ngram_proposer.py b/vllm_npu/spec_decode/ngram_proposer.py new file mode 100644 index 0000000..c5aa0a9 --- /dev/null +++ b/vllm_npu/spec_decode/ngram_proposer.py @@ -0,0 +1,72 @@ +import torch +from vllm.config import CUDAGraphMode +from vllm.v1.spec_decode.ngram_proposer import \ + NgramProposer as VllmNgramProposer + +from vllm_npu.spec_decode.interface import Proposer, SpecDcodeType + + +class NgramProposer(VllmNgramProposer, Proposer): + + def __init__(self, vllm_config, device, runner): + super().__init__(vllm_config) + self.name = SpecDcodeType.NGRAM + self.device = device + self.runner = runner + + def load_model(self, *args, **kwargs): + # No model to load. 
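On the cache_compile branch of _get_torchair_lazy_compiled_model above: the model's forward is cloned under a per-batch-size code-object name so retracing one batch size does not invalidate another's compilation cache. The sketch below shows only that renaming trick, using a plain function instead of the model.

import types

def forward(x):
    return x + 1

batch_size = 8
proxy_name = f"forward_with_batch_size_{batch_size}"
# Rebuild the function around a code object carrying the new co_name.
modified_code = forward.__code__.replace(co_name=proxy_name)
proxy = types.FunctionType(modified_code, forward.__globals__,
                           name=proxy_name, argdefs=forward.__defaults__)
print(proxy.__code__.co_name, proxy(1))   # forward_with_batch_size_8 2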
+ pass + + @torch.inference_mode() + def dummy_run(self, + num_tokens, + with_prefill=None, + skip_attn=None, + num_reqs=None, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): + pass + + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + attn_metadata=None, + aux_hidden_states=None) -> list[list[int]]: + valid_ngram_requests = [] + for i, sampled_ids in enumerate(valid_sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + continue + + req_id = self.runner.input_batch.req_ids[i] + if req_id in self.runner.input_batch.spec_decode_unsupported_reqs: + continue + + num_tokens = self.runner.input_batch.num_tokens_no_spec[i] + if num_tokens >= self.runner.input_batch.max_model_len: + # Skip requests that have already reached the max model length. + continue + + start_idx = self.runner.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + num_sampled_ids + self.runner.input_batch.token_ids_cpu[ + i, start_idx:end_idx] = sampled_ids + + valid_ngram_requests.append(i) + + draft_token_ids = self.batch_propose( + len(valid_sampled_token_ids), + valid_ngram_requests, + self.runner.input_batch.num_tokens_no_spec, + self.runner.input_batch.token_ids_cpu, + ) + + return draft_token_ids \ No newline at end of file diff --git a/vllm_npu/torchair/__init__.py b/vllm_npu/torchair/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/torchair/models/__init__.py b/vllm_npu/torchair/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/torchair/models/qwen2.py b/vllm_npu/torchair/models/qwen2.py new file mode 100644 index 0000000..daaf4ff --- /dev/null +++ b/vllm_npu/torchair/models/qwen2.py @@ -0,0 +1,363 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
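The buffer bookkeeping in NgramProposer.generate_token_ids above, reduced to toy shapes: newly sampled tokens are appended into the CPU token buffer before n-gram matching runs over it, and requests that sampled nothing or already sit at max_model_len are skipped (the real code also skips spec-decode-unsupported requests). All values below are illustrative.

import numpy as np

max_model_len = 16
token_ids_cpu = np.zeros((2, max_model_len), dtype=np.int64)
num_tokens_no_spec = np.array([3, 5])
valid_sampled_token_ids = [[7], [8, 9]]     # request 1 also kept a bonus token

valid_ngram_requests = []
for i, sampled_ids in enumerate(valid_sampled_token_ids):
    if not sampled_ids or num_tokens_no_spec[i] >= max_model_len:
        continue
    start_idx = num_tokens_no_spec[i]
    token_ids_cpu[i, start_idx:start_idx + len(sampled_ids)] = sampled_ids
    valid_ngram_requests.append(i)

print(valid_ngram_requests)             # [0, 1]
print(token_ids_cpu[1, :8].tolist())    # [0, 0, 0, 0, 0, 8, 9, 0]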
+ +from collections.abc import Iterable +from typing import Any, List, Optional, Union + +import torch +import torch.nn.functional as F +import vllm +import vllm.envs as envs +from torch import nn +from transformers import Qwen2Config +from vllm.attention import AttentionMetadata, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401 +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401 +from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model +from vllm.model_executor.models.utils import (AutoWeightsLoader, + PPMissingLayer, maybe_prefix) +from vllm.sequence import IntermediateTensors + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.attention.attention_v1 import AscendAttentionState + + +def all_gather_and_maybe_unpad( + hidden_states: torch.Tensor, + pad_size: int, +) -> torch.Tensor: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + if pad_size > 0: + return hidden_states[:-pad_size, :] + return hidden_states + + +def maybe_pad_and_reduce_scatter( + hidden_states: torch.Tensor, + pad_size: int, +) -> torch.Tensor: + if pad_size > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size)) + hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0) + return hidden_states + + +class CustomQwen2Attention(Qwen2Attention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_position=max_position, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=prefix, + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + q, k = self.rotary_emb(positions, + q, + k, + is_prefill=False, + is_qwen_torchair=True) + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = q.shape + output = torch.empty(output_shape, + 
dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = output + + attn_output = self.attn.impl.forward(self.attn, + q, + k, + v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + trace_flag=False, + **forward_kwargs) + output, _ = self.o_proj(attn_output) + return output + else: + if type(self.rotary_emb) is RotaryEmbedding: + q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) + else: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CustomQwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = CustomQwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class CustomQwen2Model(Qwen2Model): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + kv_cache = kv_caches[i - self.start_layer] \ + if kv_caches is not None else None + hidden_states, residual = layer(positions, + hidden_states, + residual, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + # add `CustomQwen2Model` to init self.model + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = CustomQwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, # type: ignore + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: 
Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM diff --git a/vllm_npu/torchair/models/qwen3_moe.py b/vllm_npu/torchair/models/qwen3_moe.py new file mode 100644 index 0000000..13e2161 --- /dev/null +++ b/vllm_npu/torchair/models/qwen3_moe.py @@ -0,0 +1,537 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from vllm/model_executor/models/qwen3_moe.py +# This file is a part of the vllm-ascend project. +from typing import Any, List, Optional, Union + +import torch +import vllm.envs as envs +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, CompilationLevel, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.interfaces import (MixtureOfExperts, + SupportsLoRA, SupportsPP) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeMLP, Qwen3MoeModel, + Qwen3MoeSparseMoeBlock) +from vllm.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.sequence import IntermediateTensors + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.torchair.ops.sequence_parallel import (MetadataForPadding, + init_metadata_for_sp) +from vllm_npu.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE + + +class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + 
nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = TorchairAscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states, + attn_metadata=None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ): + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + enable_force_load_balance = get_forward_context().in_profile_run + is_prefill = get_forward_context().with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + _metadata_for_padding=_metadata_for_padding, + ) + + return hidden_states + + +class CustomQwen3MoeAttention(Qwen3MoeAttention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
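The head-partitioning rules in this __init__ (split query heads evenly across TP ranks; split KV heads when there are at least tp_size of them, otherwise replicate) reduce to the arithmetic below; the model shapes are illustrative.

# Quick arithmetic check of the TP head-partitioning rules used above.
def partition_heads(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads to split them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each rank holds a replica.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads

print(partition_heads(32, 4, 4))   # (8, 1)  -- KV heads split one per rank
print(partition_heads(32, 4, 8))   # (4, 1)  -- KV heads replicated across ranks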
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + @staticmethod + def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int, + head_dim: int, q_norm, k_norm): + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm(k_by_head) + k = k_by_head.view(k.shape) + + return q, k, v + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size, + self.head_dim, self.q_norm, self.k_norm) + + if (self.torchair_graph_enabled and attn_metadata is not None and + attn_metadata.attn_state == AscendAttentionState.DecodeOnly): + q, k = self.rotary_emb(positions, + q, + k, + is_prefill=False, + is_qwen_torchair=True) + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = q.shape + output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = output + + attn_output = self.attn.impl.forward(self.attn, + q, + k, + v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + trace_flag=False, + **forward_kwargs) + output, _ = self.o_proj(attn_output) + return output + else: + q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + vllm_config: Optional[VllmConfig] = None, + prefix: str = "", + ) -> None: + + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = CustomQwen3MoeAttention( + 
hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + self.use_aclgraph = (vllm_config is not None + and vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE + and not vllm_config.model_config.enforce_eager) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + if not self.use_aclgraph: + # FIXME: custom sparse moe block doesn't work with aclgraph. + self.mlp = CustomSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.enable_sequence_parallelism = ( + vllm_config.compilation_config.pass_config. + enable_sequence_parallelism if vllm_config is not None else False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> torch.Tensor: + + # To prevent precision issues during the decoder phase when only prefilling enables SP + if not self.enable_sequence_parallelism: + self.self_attn.o_proj.reduce_results = True + else: + self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True + + # Self Attention + if residual is None: + residual = hidden_states + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + residual = _metadata_for_padding.padding_slice(residual) + + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned( + hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter( + hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if not self.use_aclgraph: + hidden_states = self.mlp( + hidden_states, _metadata_for_padding=_metadata_for_padding) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class 
CustomQwen3MoeModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + parallel_config = vllm_config.parallel_config + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomQwen3MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + vllm_config=vllm_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, + _metadata_for_padding=_metadata_for_padding) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned( + hidden_states) + + return hidden_states + + +class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + SupportsPP.__init__(self) + SupportsLoRA.__init__(self) + MixtureOfExperts.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomQwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = 
LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism + # Set MoE hyperparameters + self.expert_weights: list[torch.Tensor] = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3MoeDecoderLayer) + if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3MoE layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + _metadata_for_padding = init_metadata_for_sp( + input_ids, self.enable_sequence_parallelism) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds, _metadata_for_padding) + return hidden_states diff --git a/vllm_npu/torchair/models/torchair_deepseek_mtp.py b/vllm_npu/torchair/models/torchair_deepseek_mtp.py new file mode 100644 index 0000000..78ff817 --- /dev/null +++ b/vllm_npu/torchair/models/torchair_deepseek_mtp.py @@ -0,0 +1,218 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Adapted from vllm/model_executor/models/deepseek_mtp.py +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
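+
+# Torchair variants of the DeepSeek multi-token prediction (MTP) modules.
+# The predictor layer below reuses TorchairDeepseekV2DecoderLayer as its
+# transformer block, so the MTP speculative-decoding path runs through the
+# same NPU decoder implementation as the main model.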
+ +from typing import List, Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.deepseek_mtp import ( + DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, + SharedHead) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.sequence import IntermediateTensors + +from vllm_npu.torchair.models.torchair_deepseek_v2 import \ + TorchairDeepseekV2DecoderLayer + + +class TorchairDeepSeekShareHead(SharedHead): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + nn.Module.__init__(self) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head")) + + +class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer + ): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + + self.tp_size = get_tensor_model_parallel_world_size() + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = TorchairDeepSeekShareHead(config=config, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, + "shared_head")) + self.mtp_block = TorchairDeepseekV2DecoderLayer( + config, prefix, model_config, cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds = torch.where((positions == 0).unsqueeze(-1), + torch.zeros_like(inputs_embeds), + inputs_embeds) + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + replace_allreduce = hidden_states.shape[0] % self.tp_size == 0 + + hidden_states, residual = self.mtp_block( + positions=positions, + hidden_states=hidden_states, + residual=None, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + replace_allreduce=replace_allreduce) + hidden_states = residual + hidden_states + return hidden_states + + +class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + 
self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + TorchairDeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + # Note: torch._dynamo.exc.Unsupported: builtin: str + self.layers_list = [ + self.layers[str(idx)] + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + ] + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = (spec_step_idx % self.num_mtp_layers) + step_kv_cache = kv_caches[ + current_step_idx] if kv_caches is not None else None + return self.layers_list[current_step_idx]( + input_ids, + positions, + step_kv_cache, + attn_metadata, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers_list[current_step_idx] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states)) + return logits + + +class TorchairDeepSeekMTP(DeepSeekMTP): + # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; + # NOTE 2.The description file generated by the current msmodelslim tool does not have + # MTP layer info. Please manually add it and set the value to FLOAT. + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config.model_config.hf_config + self.model = TorchairDeepSeekMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + hidden_states: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, hidden_states, inputs_embeds, + spec_step_idx) + return hidden_states diff --git a/vllm_npu/torchair/models/torchair_deepseek_v2.py b/vllm_npu/torchair/models/torchair_deepseek_v2.py new file mode 100644 index 0000000..d1e02dc --- /dev/null +++ b/vllm_npu/torchair/models/torchair_deepseek_v2.py @@ -0,0 +1,1301 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch_npu +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.distributed.parallel_state import get_dp_group, get_ep_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import \ + DeepseekV2ForCausalLM # noqa: E501 +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # noqa: E501 +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention, + get_spec_layer_idx_from_weight_name) +from vllm.model_executor.models.utils import ( + PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.sequence import IntermediateTensors + +from vllm_npu import envs +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.models.layers.sfa import Indexer +from vllm_npu.ops.weight_prefetch import maybe_npu_prefetch +from vllm_npu.quantization.quant_config import AscendLinearMethod +from vllm_npu.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE +from vllm_npu.torchair.quantization.torchair_w8a8_dynamic import \ + 
TorchairAscendW8A8DynamicLinearMethod +from vllm_npu.utils import dispose_tensor, oproj_tp_enable + + +class TorchairDeepseekV2SiluAndMul(SiluAndMul): + + def __init__(self, + *, + weight_scale: Optional[Callable[[], torch.Tensor]] = None): + super().__init__() + self.weight_scale = weight_scale + + def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor, + torch.Tensor]]): + if isinstance(x, tuple): + assert self.weight_scale is not None + # For AscendW8A8DynamicLinearMethod: + # a dynamic scale is passed along with the quantized value. + quantized_x, dynamic_scale = x + return torch_npu.npu_dequant_swiglu_quant( + x=quantized_x, + weight_scale=self.weight_scale(), + activation_scale=dynamic_scale, + activate_left=True, + quant_mode=1) + else: + return super().forward_oot(x) + + +class TorchairDeepseekV2MergedReplicatedLinear(ReplicatedLinear): + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + self.output_sizes = output_sizes + super().__init__(input_size, + sum(output_sizes), + bias=bias, + quant_config=quant_config, + prefix=prefix) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, loaded_shard_id: int): + # With no support for GGUF format yet. + assert not getattr(param, "is_gguf_weight", False) + assert not getattr(param, "is_gguf_weight_type", False) + + assert loaded_shard_id < len(self.output_sizes) + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] + shard = param.data.narrow(param.output_dim, shard_offset, shard_size) + + assert shard.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter shard of id {loaded_shard_id} size {shard.size()}" + ) + shard.copy_(loaded_weight) + + +class TorchairDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear): + + def forward( + self, + input_, + is_prefill=True, + is_force_scatter=False + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. 
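+        # (The matmul below is identical to the parent RowParallelLinear;
+        # this subclass only changes the collective afterwards, preferring a
+        # token-dim reduce-scatter over an all-reduce whenever a scatter is
+        # forced or the decode-time token count is divisible by tp_size.)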
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + forward_context = get_forward_context() + if self.reduce_results and self.tp_size > 1: + num_tokens = output_parallel.shape[0] + if is_force_scatter and num_tokens % self.tp_size: + output_parallel = nn.functional.pad( + output_parallel, (0, 0, 0, -num_tokens % self.tp_size)) + if is_force_scatter or (not forward_context.with_prefill + and output_parallel.shape[0] % self.tp_size + == 0): + output = tensor_model_parallel_reduce_scatter(output_parallel, + dim=0) + else: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class TorchairDeepseekV2RowParallelLinear(RowParallelLinear): + + def forward( + self, + input_, + is_prefill=True, + is_force_scatter=False + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class TorchairDeepseekV2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + force_replicate: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + if not force_replicate: + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + else: + self.gate_up_proj = TorchairDeepseekV2MergedReplicatedLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = ReplicatedLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + + quant_method = self.gate_up_proj.quant_method + if isinstance(quant_method, UnquantizedLinearMethod): + self.act_fn = TorchairDeepseekV2SiluAndMul() + elif (isinstance(quant_method, AscendLinearMethod) + and isinstance(quant_method.quant_method, + TorchairAscendW8A8DynamicLinearMethod)): + # TODO(sdmyzlp): Currently preserved as before: + # 1. The only quantization supported for silu is W8A8Dynamic + # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16 + # + # Maybe one can implement a better and more general configuration + # scheme, e.g. by somehow passing around the tweaked `quant_config` + self.act_fn = TorchairDeepseekV2SiluAndMul( + # Use lazy binding, for `weight_scale_fp32` is accessible + # only after `process_weights_after_loading`. + weight_scale=lambda: self.gate_up_proj.weight_scale_fp32) + # To be consumed by AscendW8A8DynamicLinearMethod.apply() + self.gate_up_proj._ascend_quant_config = { + "output_dtype": torch.int32, + "pertoken_scale": False, + "return_scale": True, + } + self.down_proj._ascend_quant_config = { + "output_dtype": torch.bfloat16, + "pertoken_scale": True, + "return_scale": False, + } + else: + raise NotImplementedError( + f"Quantization with [{type(quant_method)}] is NOT supported") + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class TorchairDeepseekV2MoE(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.multistream_overlap_shared_expert = \ + ascend_config.multistream_overlap_shared_expert and \ + self.torchair_graph_enabled + + self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel + self.params_dtype = torch.float32 if self.enable_super_kernel else \ + torch.get_default_dtype() + # Converting gate weight to fp32 is to adapt to the super kernel feature. + # Super kernel feature currently cannot fuse operators such as cast, stridedslice, and add. + # In the moe stage, Cast will interrupt the fusion of the super kernel. To avoid this problem, + # modifications will be made in the initialization stage. 
+ self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + params_dtype=self.params_dtype, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=self.params_dtype)) + else: + self.gate.e_score_correction_bias = None + + self.experts = TorchairAscendFusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + self.all_reduce_merge = self.experts.all_reduce_merge + reduce_results = not self.all_reduce_merge + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.shared_experts = TorchairDeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=reduce_results, + force_replicate=self.multistream_overlap_shared_expert + or enable_shared_expert_dp, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None # type: ignore + TorchairDeepseekV2MoE.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + self.rm_router_logits = self.experts.rm_router_logits + + def forward(self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + replace_allreduce: bool = False) -> torch.Tensor: + + forward_context = get_forward_context() + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. 
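+        # NOTE: `replace_allreduce` comes from the decoder layer and is only
+        # True when multistream shared-expert overlap is active and the token
+        # count is divisible by the TP size
+        # (see TorchairDeepseekV2DecoderLayer.forward).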
+ + enable_force_load_balance = forward_context.in_profile_run + + is_prefill = forward_context.with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits = None + if not self.rm_router_logits and not self.multistream_overlap_shared_expert: + router_logits, _ = self.gate(hidden_states) + + experts_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=TorchairDeepseekV2MoE.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=self.shared_experts, + gate=self.gate, + replace_allreduce=replace_allreduce) + + hidden_states = ( + experts_hidden_states[0] * self.routed_scaling_factor + + experts_hidden_states[1]) + if self.all_reduce_merge: + # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + return hidden_states + + +class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + decoder_layer=None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_mla = \ + ascend_config.torchair_graph_config.enable_multistream_mla + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = 
ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + + if oproj_tp_enable(): + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + elif (config.n_routed_experts is not None + and self.debug_layer_idx >= config.first_k_dense_replace + and self.debug_layer_idx % config.moe_layer_freq == 0 + and (ascend_config.multistream_overlap_shared_expert + or self.enable_shared_expert_dp)): + self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + else: + self.o_proj = TorchairDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. + # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() + enable_multistream_mla = (self.enable_multistream_mla + and attn_metadata is not None + and not forward_context.with_prefill + and attn_metadata.num_decodes > 0) + forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} + if self.q_lora_rank is not None: + maybe_npu_prefetch(self.q_a_proj.weight, + hidden_states, + enabled=enable_multistream_mla) + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + forward_kwargs['ckq'] = ckq + else: + hidden_states_or_q_c = hidden_states + if self.torchair_graph_enabled: + output_shape = hidden_states.shape + output = torch.empty(output_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) + forward_kwargs['output'] = output + output = 
self.mla_attn.impl.forward(self.mla_attn, + hidden_states_or_q_c, + hidden_states, None, kv_cache, + attn_metadata, + **forward_kwargs) + output = output.view(-1, output_shape[-1]) + return output + else: + kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + hidden_states_or_q_c = get_tp_group().all_gather( + hidden_states_or_q_c, 0) + kv_no_split = get_tp_group().all_gather(kv_no_split, 0) + + kv_c, k_pe = kv_no_split.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: + output_shape = hidden_states.shape + else: + num_tokens = hidden_states_or_q_c.shape[0] + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + return self.mla_attn(hidden_states_or_q_c, + kv_c_normed, + k_pe, + output_shape=output_shape) + + +class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + decoder_layer=None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + return_bias=False, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + return_bias=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + return_bias=False, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + return_bias=False, + ) + self.kv_a_layernorm = 
RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + return_bias=False, + ) + if (config.n_routed_experts is not None + and self.debug_layer_idx >= config.first_k_dense_replace + and self.debug_layer_idx % config.moe_layer_freq == 0 + and (ascend_config.multistream_overlap_shared_expert + or self.enable_shared_expert_dp)): + self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + else: + self.o_proj = TorchairDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.dim: int = config.hidden_size # 7168 + # TODO(zzzzwwjj): wait transformers add these params + self.n_heads: int = 64 # 64 + self.head_dim: int = 128 # 128 + self.index_topk: int = 2048 # 2048 + self.indexer = Indexer( + config, + quant_config=quant_config, + dim=self.dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + index_topk=self.index_topk, + prefix=f"{prefix}.indexer", + ) + + self.sfa_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + use_sparse=True, + # SFA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + indexer=self.indexer, + decoder_layer=decoder_layer, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() + if not self.torchair_graph_enabled: + if forward_context.attn_metadata is not None and isinstance( + forward_context.attn_metadata, dict): + attn_metadata = next( + iter(forward_context.attn_metadata.values()), None) + else: + attn_metadata = forward_context.attn_metadata + if kv_cache is None: + kv_cache = self.sfa_attn.kv_cache[ + forward_context.virtual_engine] + + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + # if self.enable_shared_expert_dp and self.debug_layer_idx > 
self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # # Simulate all gather to calculate output shape + # num_tokens = num_tokens * self.tp_size + # need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace: + output_shape = hidden_states.shape + if self.enable_shared_expert_dp and ( + self.debug_layer_idx == self.first_k_dense_replace + or self.debug_layer_idx == self.layers): + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, + need_gather_q_kv, output) + output = output.view(-1, output_shape[-1]) + return output + + +class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + self.layers = config.num_hidden_layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + ascend_config = get_ascend_config() + self.use_mla = False + self.use_sparse = False + # TODO: enable mla in vllm-ascend + if model_config.use_mla: + if hasattr(model_config.hf_config, "index_topk"): + attn_cls = TorchairDeepseekV2SFAAttention + self.use_sparse = True + else: + attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment] + self.use_mla = True + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + decoder_layer=self, + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = TorchairDeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.mla_moe_communication = ascend_config.multistream_overlap_shared_expert \ + and model_config.use_mla and self.tp_size > 1 + else: + self.mlp = TorchairDeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.mla_moe_communication = False + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + 
self.first_k_dense_replace = config.first_k_dense_replace + self.tp_group = get_tp_group().device_group + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + replace_allreduce: bool = False, + ) -> torch.Tensor: + # Self Attention + forward_context = get_forward_context() + if attn_metadata is not None: + decoding_condition_met = ( + not attn_metadata.is_prefill if self.use_sparse else + not forward_context.with_prefill if self.use_mla else False) + mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce + else: + mla_moe_communication = False + + if (envs.vllm_npu_ENABLE_MLAPO + and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention) + and attn_metadata is not None + and not forward_context.with_prefill): + if residual is not None: + hidden_states = hidden_states + residual + residual = hidden_states + else: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + previous_hidden_states, previous_residual = hidden_states, residual + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + # Dispose hidden_states and residual from the previous layer + # to save npu memory because they're no longer used. + dispose_tensor(previous_hidden_states) + dispose_tensor(previous_residual) + if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if mla_moe_communication and residual.shape[0] != hidden_states.shape[ + 0]: + chunk_hidden_states = torch.tensor_split(residual, + self.tp_size, + dim=0) + residual = chunk_hidden_states[self.tp_rank] + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + tp_size = get_tensor_model_parallel_world_size() + if self.enable_shared_expert_dp and ( + self.layer_idx == self.first_k_dense_replace + or self.layer_idx == self.layers) and tp_size > 1: + num_tokens, _ = residual.shape + if num_tokens % tp_size: + residual = nn.functional.pad(residual, + (0, 0, 0, -num_tokens % tp_size)) + chunk_residual = torch.tensor_split(residual, tp_size, dim=0) + tp_rank = get_tensor_model_parallel_rank() + residual = chunk_residual[tp_rank] + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, TorchairDeepseekV2MoE): + hidden_states = self.mlp(hidden_states, + attn_metadata, + replace_allreduce=mla_moe_communication) + else: + hidden_states = self.mlp(hidden_states) + + if isinstance(self.mlp, TorchairDeepseekV2MLP + ) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. 
+ # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. / self.routed_scaling_factor + if mla_moe_communication and self.layer_idx >= self.layers - 1: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) + residual = tensor_model_parallel_all_gather(residual, dim=0) + + # for last layer of main model and mtp layer. + if self.enable_shared_expert_dp and self.layer_idx >= ( + self.layers - 1) and tp_size > 1: + hidden_states = get_tp_group().all_gather(hidden_states, 0) + residual = get_tp_group().all_gather(residual, 0) + + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values()), None) + if attn_metadata is not None: + num_tokens = attn_metadata.num_actual_tokens + else: + num_tokens = hidden_states.shape[0] + + if num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:num_tokens] + residual = residual[:num_tokens] + + return hidden_states, residual + + +class TorchairDeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.tp_size = get_tensor_model_parallel_world_size() + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: TorchairDeepseekV2DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + replace_allreduce = hidden_states.shape[0] % self.tp_size == 0 + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, + replace_allreduce=replace_allreduce) + + if not get_pp_group().is_last_rank: + return 
IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class TorchairDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): + # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.num_dense_layers = self.config.first_k_dense_replace + self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers + self.quant_config = quant_config + self.model = TorchairDeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # NOTE: This `load_weights` is mainly copied from + # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 + # to fix CI, and it is different from the implementation in main + # TODO: support eplb style load_weights + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + """""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = TorchairAscendFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "module" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + return_success=False) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states diff --git a/vllm_npu/torchair/models/torchair_deepseek_v3.py b/vllm_npu/torchair/models/torchair_deepseek_v3.py new file mode 100644 index 0000000..a10a722 --- /dev/null +++ b/vllm_npu/torchair/models/torchair_deepseek_v3.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vllm_npu.torchair.models.torchair_deepseek_v2 import \ + TorchairDeepseekV2ForCausalLM + + +class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM): + pass diff --git a/vllm_npu/torchair/models/torchair_pangu_moe.py b/vllm_npu/torchair/models/torchair_pangu_moe.py new file mode 100644 index 0000000..4eb13bc --- /dev/null +++ b/vllm_npu/torchair/models/torchair_pangu_moe.py @@ -0,0 +1,1118 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# This file is a part of the vllm-ascend project. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch_npu +from torch import nn +from torch.nn import Parameter +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (divide, get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group, get_world_group) +from vllm.forward_context import get_forward_context +from vllm.logger import logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import ( + extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.v1.sample.sampler import Sampler + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_310p + +_ROUTER_SCALE = None + + +def use_h2p(): + # only use H2P when dp_size > 1. + if get_dp_group().world_size > 1: + return True + return False + + +# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear. +# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp). +class CustomMergedColumnParallelLinear(LinearBase): + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the last dimension. 
+ output_size = sum(output_sizes) + self.output_sizes = output_sizes + self.tp_size = get_tp_group().world_size + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + loaded_shard_id: int): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + assert loaded_shard_id < len(self.output_sizes) + + tp_rank = get_tp_group().rank_in_group + tp_size = get_tp_group().world_size + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + +# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear. +# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp) +# and detach communication to enable customized communication algorithms(e.g., H2P). 
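+# Unlike vLLM's stock RowParallelLinear, this variant issues no collective in
+# forward(): the caller (e.g. the H2P schedule in PanguProMoEDecoderLayer) performs
+# the reduce-scatter / all-gather explicitly, which is why communication is detached.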
+class CustomRowParallelLinear(LinearBase): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + group=None, + ): + # Divide the weight matrix along the first dimension. + self.group = group if group is not None else get_tp_group() + self.tp_rank = self.group.rank_in_group + self.tp_size = self.group.world_size + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = self.group.rank_in_group + input_dim = getattr(param, "input_dim", None) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + is_sharded_weight = is_sharded_weight + + param_data = param.data + if input_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + input_parallel = input_ + + # Matrix multiply. 
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output = self.quant_method.apply(self, input_parallel, bias=bias_) + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class PanguProMoEMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + if not use_h2p(): + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + else: + self.gate_up_proj = CustomMergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = CustomRowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def topk_wrapper(num_voted_experts): + + def pangu_group8_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool = False, + num_expert_group: int = 0, + topk_group: int = 0, + global_num_experts: int = 0, + ): + scores = F.softmax(gating_output, dim=1) + num_tokens = scores.shape[0] + router_scale = _ROUTER_SCALE.squeeze( # type: ignore + ) + # TODO: support disable expert parallel + ep_size = get_ep_group().world_size + local_num_experts = global_num_experts // ep_size + local_num_group = topk // ep_size + experts_per_group = global_num_experts // topk + local_group_start = get_ep_group().rank_in_group * local_num_experts + local_group_end = (get_ep_group().rank_in_group + + 1) * local_num_experts + scores = F.softmax(gating_output, dim=1) + scores = scores[..., local_group_start:local_group_end] + + router_weights = router_scale[local_group_start:local_group_end] + + if num_voted_experts == 8: + # use original topk + topk_weights, topk_ids = torch.max(scores.view( + scores.shape[0], local_num_group, -1), + dim=-1) + bias = torch.arange(0, + local_num_experts, + experts_per_group, + device=scores.device, + dtype=torch.int32).unsqueeze(0) + topk_ids = topk_ids.to(torch.int32) + bias + + else: + group_expert_indices = torch.arange(experts_per_group, + dtype=torch.int32, + device=scores.device).view( + 1, 1, -1) + group_expert_offset = (torch.arange( + local_num_group, dtype=torch.int32, device=scores.device) * + experts_per_group).unsqueeze(0) + expert_index_range = torch.arange(experts_per_group, + dtype=torch.int32, + device=scores.device) + + scores_grouped = scores.view(num_tokens, local_num_group, + experts_per_group) + best_expert_idx = torch.argmax(scores_grouped, + dim=2) # (num_tokens, num_groups) + vote_mask = 
(best_expert_idx.unsqueeze(-1).to( + torch.int32) == group_expert_indices) + + expert_vote_freq = vote_mask.sum(dim=0) + + sorted_indices = torch.argsort(expert_vote_freq, + dim=1, + descending=True).to(torch.int32) + topk_experts = sorted_indices[:, :num_voted_experts] + keep_mask = (( + topk_experts.unsqueeze(-1) == expert_index_range).any( + dim=1)).unsqueeze(0) + + masked_scores = torch.where(keep_mask, scores_grouped, 0) + + topk_weights, best_pos_in_group = masked_scores.max(dim=2) + best_pos_in_group = best_pos_in_group.to(torch.int32) + topk_ids = (best_pos_in_group + group_expert_offset).to( + torch.int32) + + flatten_topk_ids = topk_ids.view(-1) + router_weights = router_weights.index_select(0, flatten_topk_ids).view( + topk_ids.shape) + topk_weights *= router_weights + + return topk_weights, topk_ids + + return pangu_group8_topk + + +class PanguProMoESparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.num_experts_per_tok = config.num_experts_per_tok + self.router_scale = torch.nn.Parameter( + torch.ones((1, self.num_experts))) + + # on 300I Duo platform, we find that num_voted_experts set to 5 achieves + # good performance without sacrifice too much accuracy. for other platform, + # this is set to 8 to use original pangu grouped topk. + num_voted_experts = 5 if is_310p() else 8 + + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + quant_config=quant_config, + custom_routing_function=topk_wrapper(num_voted_experts), + prefix=f"{prefix}.experts", + ) + self.use_ep = self.experts.use_ep + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = PanguProMoEMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_expert", + ) + else: + self.shared_expert = None # type: ignore + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + global _ROUTER_SCALE + _ROUTER_SCALE = self.router_scale + + # TODO(angazenn): Does not support MC2 currently + get_forward_context().moe_comm_method_name = "allgathercommimpl" + + if not use_h2p(): + final_hidden_states = self.experts.forward_impl( + hidden_states=hidden_states, router_logits=router_logits) + else: + # TODO: when using h2p, we have to skip communication in vLLM + # native FusedMoE. 
here we need to design a better FusedMoE + # (maybe using AscendFusedMoE) to enable these different + # communication schema. + final_hidden_states = self.experts.quant_method.apply( + layer=self.experts, + x=hidden_states, + router_logits=router_logits, + top_k=self.experts.top_k, + renormalize=False, + use_grouped_topk=False, + global_num_experts=self.experts.global_num_experts, + expert_map=self.experts.expert_map, + custom_routing_function=self.experts.custom_routing_function, + apply_router_weight_on_input=self.experts. + apply_router_weight_on_input) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if not use_h2p(): + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class PanguProMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
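+            # Illustrative: total_num_kv_heads=8 with tp_size=16 passes the assert
+            # below (16 % 8 == 0) and yields num_kv_heads = max(1, 8 // 16) = 1,
+            # i.e. each rank holds a replica of one KV head.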
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + if use_h2p(): + self.o_proj = CustomRowParallelLinear(self.total_num_heads * + self.head_dim, + hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + group=get_tp_group()) + else: + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + if self.torchair_graph_enabled: + forward_kwargs = {'trace_flag': False} + output_shape = q.shape + attn_output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = attn_output + attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache, + attn_metadata, + **forward_kwargs) + else: + attn_output = self.attn(q, k, v) + + output, _ = self.o_proj(attn_output) + return output + + +class PanguProMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = PanguProMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. 
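+        # Layers listed in `mlp_only_layers` fall back to a dense PanguProMoEMLP;
+        # all other layers (when num_experts > 0) use the sparse MoE block below.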
+ layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and (config.num_experts > 0): + self.mlp = PanguProMoESparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = PanguProMoEMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + h2p_unpad_idx: Optional[torch.Tensor] = None, + h2p_pad_idx: Optional[torch.Tensor] = None, + is_start_layer: Optional[bool] = False, + ) -> torch.Tensor: + need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \ + and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0] + tp_size = get_tp_group().world_size + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + if use_h2p(): + if is_start_layer: + if need_h2p_pad: + residual = residual.index_select(dim=0, index=h2p_pad_idx) + residual = torch.tensor_split( + residual, tp_size)[get_tp_group().rank_in_group] + else: + if tp_size > 1: + hidden_states = get_tp_group().all_gather(hidden_states, 0) + if need_h2p_pad: + hidden_states = hidden_states.index_select( + dim=0, index=h2p_unpad_idx) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if use_h2p(): + if need_h2p_pad: + hidden_states = hidden_states.index_select(dim=0, + index=h2p_pad_idx) + if tp_size > 1: + hidden_states = dist._functional_collectives.reduce_scatter_tensor( + hidden_states, + "sum", + scatter_dim=0, + group=get_tp_group().device_group) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if use_h2p(): + all_rank_group = get_world_group().device_group + output_size = (hidden_states.shape[0] * + get_world_group().world_size, + hidden_states.shape[1]) + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=hidden_states.dtype, + device=hidden_states.device) + # All-gather. 
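+            # Gather token shards from every rank in the world group (TP x DP) so
+            # the MoE block routes over the full batch; the result is
+            # reduce-scattered back right after self.mlp below.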
+ dist.all_gather_into_tensor(output_tensor, + hidden_states, + group=all_rank_group) + hidden_states = output_tensor + + hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata) + + if use_h2p(): + hidden_states = dist._functional_collectives.reduce_scatter_tensor( + hidden_states, + "sum", + scatter_dim=0, + group=get_world_group().device_group) + + return hidden_states, residual + + +@support_torch_compile +class PanguProMoEModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PanguProMoEDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + if use_h2p(): + # calculate necessary padding/unpadding idx before model forward. + + # the attn_metadata will be passed directly when use torchair. + # if attn_meatadata is not passed, we try to get it from forward_context. + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + + max_tokens_across_dp = get_forward_context().max_tokens_across_dp + + tp_size = get_tp_group().world_size + # reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks. + # we need pad it before if the shape can't be divided by group size. + # for h2p, we need pad it so that it can be divided by tp_size. 
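+            # Illustrative numbers: tp_size=4, max_tokens_across_dp=10 and 7 local
+            # tokens give h2p_padded_len = (4 - 10 % 4) % 4 + 10 - 7 = 5, so every
+            # rank pads to 12 rows (the next multiple of tp_size above 10).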
+ h2p_padded_len = ( + tp_size - (max_tokens_across_dp % tp_size) + ) % tp_size + max_tokens_across_dp - hidden_states.shape[0] + h2p_unpad_idx = torch.arange(hidden_states.shape[0], + device=hidden_states.device, + dtype=torch.int32) + h2p_pad_idx = torch.cat([ + h2p_unpad_idx, + torch.zeros(h2p_padded_len, + dtype=torch.int32, + device=hidden_states.device) + ]) + else: + h2p_unpad_idx = None + h2p_pad_idx = None + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, h2p_unpad_idx, h2p_pad_idx, + i == self.start_layer) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + if use_h2p(): + if get_tp_group().world_size > 1: + hidden_states = get_tp_group().all_gather(hidden_states, 0) + if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]: + hidden_states = hidden_states.index_select(dim=0, + index=h2p_unpad_idx) + return hidden_states + + +class PanguProMoEForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = PanguProMoEModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, # type: ignore + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata, # type: ignore + ): + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + 
("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) # from model + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + # ======================================================= + # BF: add this to load with less layers + if 'layers' in name: + layer_idx = int(name.split('layers.')[-1].split('.')[0]) + if layer_idx >= self.model.end_layer: + continue + + if "rotary_emb.inv_freq" in name: + continue + + if "module" in name: + continue + + if name.endswith('kv_cache_offset'): + continue + + if name.endswith("k_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace( + "k_proj.kv_cache_scale", "attn.key_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, + tp_size, + dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if name.endswith("v_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace( + "v_proj.kv_cache_scale", "attn.value_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, + tp_size, + dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + # breakpoint() + name = name.replace(weight_name, param_name) + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + # breakpoint() + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + if is_310p() and "head" in name: + # on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than + # ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented + # by linear, we manually cast the format here. + param.data = torch_npu.npu_format_cast(param.data, + ACL_FORMAT_FRACTAL_NZ) + return loaded_params diff --git a/vllm_npu/torchair/ops/__init__.py b/vllm_npu/torchair/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/torchair/ops/sequence_parallel.py b/vllm_npu/torchair/ops/sequence_parallel.py new file mode 100644 index 0000000..d055437 --- /dev/null +++ b/vllm_npu/torchair/ops/sequence_parallel.py @@ -0,0 +1,120 @@ +import torch +from torch.nn import functional as F +from vllm.distributed import (get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import get_forward_context + +from vllm_npu.platform import NPUPlatform + + +class MetadataForPadding: + + def __init__(self, + padding_flag=False, + lengths_sum_padding=0, + lengths_sum_unpadding=0, + pad_size=0, + not_dummy_and_is_prefill=False): + self.padding_flag = padding_flag + self.not_dummy_and_is_prefill = not_dummy_and_is_prefill + + self.lengths_sum_padding = lengths_sum_padding + self.lengths_sum_unpadding = lengths_sum_unpadding + self.pad_size = pad_size + + self.tp_size = get_tp_group().world_size + self.tp_rank_in_group = get_tp_group().rank_in_group + + assert self.lengths_sum_padding % self.tp_size == 0 + self.slice_size = self.lengths_sum_padding // self.tp_size + + self.mc2_mask = torch.zeros( + self.lengths_sum_padding, + dtype=torch.bool, + device=NPUPlatform.device_type, + ) + self.mc2_mask[:lengths_sum_unpadding] = True + + def padding_aligned_reduce_scatter(self, + data: torch.Tensor) -> torch.Tensor: + if self.padding_flag: + pad_size = self.pad_size + padded_data = F.pad(data, (0, 0, 0, pad_size)) + else: + padded_data = data + padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter( + padded_data, 0) + + return padded_data_reduce_scatter + + def allgather_unpadding_aligned(self, + padded_data: torch.Tensor) -> torch.Tensor: + padded_data_allgather = tensor_model_parallel_all_gather( + padded_data, 0) + if self.padding_flag: + 
lengths_sum_unpadding = self.lengths_sum_unpadding + unpadding_data = padded_data_allgather[:lengths_sum_unpadding] + else: + unpadding_data = padded_data_allgather + return unpadding_data + + def padding_slice(self, data: torch.Tensor) -> torch.Tensor: + + padded_data = F.pad(data, (0, 0, 0, self.pad_size)) + start = self.tp_rank_in_group * self.slice_size + end = start + self.slice_size + slice_data = padded_data[start:end] + + return slice_data + + def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor: + if self.padding_flag: + pad_size = self.pad_size + padded_data = F.pad(data, (0, 0, 0, pad_size)) + else: + padded_data = data + # padded_data = data + padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0) + + padded_data_reduce_scatter = padded_data[self.tp_rank_in_group] + + return padded_data_reduce_scatter + + +def init_metadata_for_sp(input_ids, enable_sequence_parallelism): + if not enable_sequence_parallelism: + return MetadataForPadding(padding_flag=False, + not_dummy_and_is_prefill=False) + + is_perifll = 0 + attn_metadata = get_forward_context().attn_metadata + tp_size = get_tensor_model_parallel_world_size() + if attn_metadata is not None: + if hasattr(attn_metadata, + 'is_only_prefill') and attn_metadata.is_only_prefill: + is_perifll = 1 + if hasattr(attn_metadata, + 'num_prefills') and attn_metadata.num_prefills > 0: + is_perifll = 1 + + if is_perifll: + lengths_sum_unpadding = input_ids.shape[0] + lengths_sum_padding = ( + (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size + if lengths_sum_unpadding == lengths_sum_padding: + padding_flag = False + else: + padding_flag = True + pad_size = lengths_sum_padding - lengths_sum_unpadding + _metadata_for_padding = MetadataForPadding( + lengths_sum_unpadding=lengths_sum_unpadding, + lengths_sum_padding=lengths_sum_padding, + padding_flag=padding_flag, + pad_size=pad_size, + not_dummy_and_is_prefill=True) + + return _metadata_for_padding + + return MetadataForPadding(padding_flag=False, + not_dummy_and_is_prefill=False) diff --git a/vllm_npu/torchair/ops/shared_weight_layer.py b/vllm_npu/torchair/ops/shared_weight_layer.py new file mode 100644 index 0000000..6ab29af --- /dev/null +++ b/vllm_npu/torchair/ops/shared_weight_layer.py @@ -0,0 +1,245 @@ +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +import torch.distributed as dist +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.model_executor.layers.linear import LinearBase + + +def dispose_tensor(x: torch.Tensor): + x.set_(torch.empty([], device=x.device, dtype=x.dtype)) + + +@dataclass +class LayerMetadata: + """Metadata for a layer. + """ + layer: Optional[LinearBase] # The layer object. + post_method: Callable[[ + torch.nn.Module + ], None] # The `process_weights_after_loading` method from the quant method. + weight: torch.Tensor # The weight tensor. + window_idx: int # The index of the window. + + +@dataclass +class SharedWindowMetadata: + """Metadata for a shared window. + """ + weight: torch.Tensor # The weight tensor to be shared by layers. + data_layer_idx: int # The index of the layer this window's weight is equal to. + work: Optional[torch.distributed.Work] # The asynchronous broadcast work. + + +@dataclass +class SeriesMetadata: + """Metadata for a weight shared series. + """ + group: GroupCoordinator + start_layer: int + end_layer: int + num_layers: int + prefetch_step: int + dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. 
All the layers in the series share the same dummy weight tensor. + layers: list[LayerMetadata] + shared_windows: list[ + SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored. + window_offset: int # The index of the window for the next coming layer. + + def is_source(self, layer_idx) -> bool: + return layer_idx % self.group.world_size == self.group.rank_in_group + + def post_process_after_loading(self): + # This method only needs to be called once per series. + if self.shared_windows: + return + for layer_idx in range(self.start_layer, self.end_layer): + layer = self.layers[layer_idx - self.start_layer] + is_source = self.is_source(layer_idx) + # If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight. + if not is_source: + layer.weight.set_(torch.empty_like(self.dummy_weight)) + # Broadcast to get the true weight. + dist.broadcast(layer.weight, + src=self.group.ranks[layer_idx % + self.group.world_size], + group=self.group.device_group) + assert layer.layer is not None + # Call `process_weights_after_loading` from the quant method. + layer.post_method(layer.layer) + step = layer_idx - self.start_layer + if step < self.prefetch_step: + # Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights. + self.shared_windows.append( + SharedWindowMetadata( + weight=layer.weight.clone().detach(), + data_layer_idx=layer_idx, + work=None, + )) + layer.window_idx = step + # When the layer not intended to be stored in this device, link to the corresponding window's tensor. + if not is_source: + layer.weight.set_(self.shared_windows[-1].weight) + else: + # Build one more window for prefetch. The weight is useless, so just keep the shape. + if step == self.prefetch_step: + self.shared_windows.append( + SharedWindowMetadata( + weight=torch.empty_like(layer.weight), + data_layer_idx=-1, + work=None, + )) + # When the layer not intended to be stored in this device, dispose the tensor. + if not is_source: + dispose_tensor(layer.weight) + + dispose_tensor(self.dummy_weight) + + def reach_layer(self, layer_idx: int): + # The index of the layer to be prefetched. + next_layer_idx = (layer_idx + self.prefetch_step + ) % self.num_layers + self.start_layer + next_layer = self.layers[next_layer_idx - self.start_layer] + # The index of the window to store the weight for the coming layer. + next_layer.window_idx = self.window_offset + window = self.shared_windows[next_layer.window_idx] + # When the layer not intended to be stored in this device, link to the corresponding window's tensor. + if not self.is_source(next_layer_idx): + next_layer.weight.set_(window.weight) + # Update `window_offset` by rolling one step. + self.window_offset = (self.window_offset + 1) % (self.prefetch_step + + 1) + assert window.data_layer_idx != next_layer_idx + window.data_layer_idx = next_layer_idx + # Start asynchronous broadcast work. + window.work = dist.broadcast( + next_layer.weight, + src=self.group.ranks[next_layer_idx % self.group.world_size], + group=self.group.device_group, + async_op=True) + + def wait_weight(self, layer_idx: int): + # Find the asynchronous broadcast work and wait for it. 
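+        # The window for this layer was filled when reach_layer() was called
+        # `prefetch_step` layers earlier, so the broadcast is usually done by now.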
+ assert self.shared_windows + window = self.shared_windows[self.layers[layer_idx - + self.start_layer].window_idx] + # Make sure the data in the corresponding shared window is for the current layer. + assert window.data_layer_idx == layer_idx + if window.work is not None: + window.work.wait() + window.work = None + + +@dataclass +class LayerExternalMetadata: + """External metadata for a layer. + """ + series: SeriesMetadata + layer_idx: int + + +_series_dict: dict[str, SeriesMetadata] = {} + +_layer_external_dict: dict[int, LayerExternalMetadata] = {} + + +def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, + layer_idx: int) -> Callable: + + def wrapped_forward(*args, **kwargs): + # Wait for the weight. + series.wait_weight(layer_idx) + return forward(*args, **kwargs) + + return wrapped_forward + + +""" +Register linear layers into a shared storage series. + +In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices. + +After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization. + +During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series. + +Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula: +- total_layers = end_layer - start_layer +- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer + +To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series. + +Arguments: + series_name: This name identifies which series this layer belongs to. + group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series. + start_layer: The index of the first layer in the series (inclusive). + end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer). + layer_idx: The index of the current layer. + layer: The linear layer object to register. + prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases. 
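+
+Example (illustrative sketch; `group`, `layers` and the `o_proj` attribute are
+placeholders, not part of this module):
+
+    for idx in range(start_layer, end_layer):
+        register_layer_to_shared_weight_series(
+            series_name="o_proj_series",
+            group=group,
+            start_layer=start_layer,
+            end_layer=end_layer,
+            layer_idx=idx,
+            layer=layers[idx].o_proj,
+            prefetch_step=1,
+        )
+    # Once, after the model weights have been loaded:
+    post_process_after_loading_for_shared_weight_series(layers[start_layer].o_proj)
+    # In forward, when execution reaches layer idx:
+    reach_layer_for_shared_weight_series(layers[idx].o_proj)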
+""" + + +def register_layer_to_shared_weight_series( + series_name: str, + group: GroupCoordinator, + start_layer: int, + end_layer: int, + layer_idx: int, + layer: LinearBase, + prefetch_step: int = 1, +): + global _series_dict + if series_name not in _series_dict: + num_layers = end_layer - start_layer + assert num_layers > 0 + assert prefetch_step >= 0 and prefetch_step <= num_layers - 2 + _series_dict[series_name] = SeriesMetadata( + group=group, + start_layer=start_layer, + end_layer=end_layer, + num_layers=num_layers, + prefetch_step=prefetch_step, + dummy_weight=torch.empty_like(layer.weight), + layers=[ + LayerMetadata( + layer=None, + post_method=lambda layer: None, + weight=torch.empty([]), + window_idx=-1, + ) for _ in range(num_layers) + ], + shared_windows=[], + window_offset=prefetch_step, + ) + series = _series_dict[series_name] + assert layer.quant_method is not None + series.layers[layer_idx - start_layer] = LayerMetadata( + layer=layer, + post_method=layer.quant_method.process_weights_after_loading, + weight=layer.weight, + window_idx=-1, + ) + # Discard the original `process_weights_after_loading` method such that it won't be called by others. + layer.quant_method.process_weights_after_loading = lambda layer: None + # When the layer not intended to be stored in this device, dispose the tensor and skip weight loading. + if not series.is_source(layer_idx): + dispose_tensor(layer.weight) + layer.weight.weight_loader = lambda *args, **kwargs: None + layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx) + global _layer_external_dict + _layer_external_dict[id(layer)] = LayerExternalMetadata( + series=series, + layer_idx=layer_idx, + ) + + +def post_process_after_loading_for_shared_weight_series(layer: LinearBase): + ext = _layer_external_dict[id(layer)] + ext.series.post_process_after_loading() + + +def reach_layer_for_shared_weight_series(layer: LinearBase): + ext = _layer_external_dict[id(layer)] + ext.series.reach_layer(ext.layer_idx) diff --git a/vllm_npu/torchair/ops/torchair_activation.py b/vllm_npu/torchair/ops/torchair_activation.py new file mode 100644 index 0000000..637df2d --- /dev/null +++ b/vllm_npu/torchair/ops/torchair_activation.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch + + +def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor: + """AscendSiluAndMul forward in torchair mode. + + The key difference from the original implementation is the removal of operators + from the torch.ops.vllm class, as these operators only function in non-torchair + modes. Adding them back would cause the graph compilation to fail. 
+ """ + + import torch_npu + + from vllm_npu.utils import is_310p + + if is_310p(): + out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) + else: + out = torch_npu.npu_swiglu(x) + return out diff --git a/vllm_npu/torchair/ops/torchair_fused_moe.py b/vllm_npu/torchair/ops/torchair_fused_moe.py new file mode 100644 index 0000000..a0ffb2f --- /dev/null +++ b/vllm_npu/torchair/ops/torchair_fused_moe.py @@ -0,0 +1,1409 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/kernels/test_moe.py + +import os +from typing import Any, Callable, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch_npu +from torch import nn +from vllm.config import get_current_vllm_config +from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) +from vllm.forward_context import get_forward_context +from vllm.logger import logger +from vllm.model_executor.layers.fused_moe.config import \ + FusedMoEConfig # isort: skip +from vllm.model_executor.layers.fused_moe.config import \ + FusedMoEParallelConfig # isort: skip +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map, + get_compressed_expert_map) +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.ascend_forward_context import FusedMoEState +from vllm_npu.distributed.parallel_state import get_mc2_group +from vllm_npu.eplb.core.eplb_utils import (determine_default_expert_map, + determine_default_log2phy_map) +from vllm_npu.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_npu.quantization.quant_config import AscendFusedMoEMethod +from vllm_npu.torchair.ops.sequence_parallel import MetadataForPadding +from vllm_npu.torchair.utils import (get_all_reduce_merge_state, + get_rm_router_logits_state, + npu_stream_switch, npu_wait_tensor, + super_kernel) +from vllm_npu.utils import (AscendSocVersion, dispose_tensor, + get_ascend_soc_version, is_310p, + is_hierarchical_communication_enabled) + + +def torchair_fused_experts_with_mc2( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + moe_parallel_config: FusedMoEParallelConfig, + expert_map: torch.Tensor = None, + moe_all_to_all_group_name: Optional[str] = None, + shared_experts: Optional[Any] = None, + is_torchair: bool = False, + mc2_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + quant_mode = 0 + ep_rank_id = moe_parallel_config.ep_rank + ep_world_size = moe_parallel_config.ep_size + + # NOTE: 
Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and + # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly + # improve communication performance. + need_expert_scale = is_hierarchical_communication_enabled() + + enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") + + moe_expert_num = len(expert_map) + kwargs_mc2 = { + "x": hidden_states, + "expert_ids": topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + + stage1_kwargs = { + "scales": None, + "quant_mode": quant_mode, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and enable_dispatch_v2: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) + if need_expert_scale: + stage1_kwargs.update({ + "expert_scales": topk_weights.to(torch.float32), + }) + + kwargs_mc2.update(stage1_kwargs) + + output = torch_npu.npu_moe_distribute_dispatch_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( + **kwargs_mc2) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ + ep_recv_counts, _, expand_scales = output[0:7] + + if shared_experts is not None: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(hidden_states, topk_weights) + shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) + npu_wait_tensor(shared_gate_up, expand_x) + shared_act = shared_experts.act_fn(shared_gate_up) + + w1 = w1.transpose(1, 2) + + group_list = expert_token_nums.to(torch.int64) + gate_up_out_list = torch_npu.npu_grouped_matmul( + x=[expand_x], + weight=[w1], + split_item=2, + # 1 means count mode, to avoid cumulative operation of the group list + group_list_type=1, + group_type=0, + group_list=group_list, + )[0] + + gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) + + w2 = w2.transpose(1, 2) + down_out_list = torch_npu.npu_grouped_matmul( + x=[gate_up_out], + weight=[w2], + split_item=2, + group_list_type=1, + group_type=0, + group_list=group_list, + )[0] + + # moeCombine + kwargs_mc2 = { + "expand_x": down_out_list, + "expert_ids": topk_ids, + "expert_scales": topk_weights.to(torch.float32), + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + tp_recv_counts = output[5] + stage3_kwargs = { + "ep_send_counts": ep_recv_counts, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + "expand_scales": expand_scales, + } + if enable_dispatch_v2: + stage3_kwargs.update({ + "assist_info_for_combine": + assist_info_for_combine, + }) + else: + stage3_kwargs.update({ + "expand_idx": assist_info_for_combine, + }) + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and 
enable_dispatch_v2: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) + kwargs_mc2.update(stage3_kwargs) + + hidden_states = torch_npu.npu_moe_distribute_combine_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( + **kwargs_mc2) + + if shared_experts is None: + return hidden_states + else: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(shared_act, down_out_list) + shared_hidden_states, _ = shared_experts.down_proj(shared_act) + return hidden_states, shared_hidden_states + + +def torchair_apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, +) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + + Returns: + hidden_states: output hidden states after MLP. + """ + + w1 = w1.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + + hidden_states = torch_npu.npu_swiglu(hidden_states) + + w2 = w2.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + + return hidden_states + + +# currently expert parallelism implemented with all2all +# is under-optimized. 
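+#
+# Rough data flow of the all2all path below (illustrative sketch only;
+# steps 2 and 4 are skipped when no expert_map is given):
+#   1. npu_moe_init_routing      expand tokens to num_tokens * top_k rows,
+#                                sorted by global expert id
+#   2. all_to_all (dispatch)     exchange rows so each EP rank holds only the
+#                                rows routed to its local experts
+#   3. grouped_matmul -> swiglu -> grouped_matmul (per-expert MLP)
+#   4. all_to_all (combine)      send results back to the originating ranks
+#   5. npu_moe_finalize_routing  weight and scatter rows back to
+#                                (num_tokens, hidden_size)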
+def torchair_fused_experts_with_all2all( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + num_experts = w1.shape[0] + device = hidden_states.device + + if expert_map is not None: + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(top_k, -1).permute( + 1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + global_expert_tokens = torch.bincount(expanded_expert_idx, + minlength=global_num_experts) + scatter_sizes = global_expert_tokens.view(ep_group.world_size, + -1).sum(-1) + + gather_sizes = torch.empty_like(scatter_sizes) + dist.all_to_all_single(gather_sizes, + scatter_sizes, + group=ep_group.device_group) + scatter_size_list = scatter_sizes.cpu().tolist() + gather_size_list = gather_sizes.cpu().tolist() + + expanded_expert_idx = expanded_expert_idx % local_num_experts + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + scatter_size_list, + gather_size_list) + local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, + scatter_size_list, + gather_size_list) + + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + + hidden_states = hidden_states[sorted_idx] + else: + row_idx_len = num_tokens * top_k + row_idx = torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous() + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts) + expert_tokens = expert_tokens.to(torch.int64) + + w1 = w1.transpose(1, 2) + gate_up_out_list = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + )[0] + + hidden_states = torch_npu.npu_swiglu(gate_up_out_list) + + w2 = w2.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + )[0] + + if expert_map is not None: + resorted_idx = torch.argsort(sorted_idx) + hidden_states = hidden_states[resorted_idx] + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + gather_size_list, + scatter_size_list) + + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + else: + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. 
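+        # npu_moe_finalize_routing combines the expanded expert outputs back
+        # to their source-token order, weighting each row by `scales`, and
+        # yields a (num_tokens, hidden_size) tensor in one call.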
+ final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + +def torchair_fused_experts_moge( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + moe_parallel_config: FusedMoEParallelConfig, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + global_num_experts: int, + expert_map: torch.Tensor = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + """ + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). + w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + top_k: Number of experts to select. + expert_map: Expert mapping of shape (num_experts,). + + Returns: + hidden_states: Hidden states after routing. + """ + ep_size = moe_parallel_config.ep_size + local_num_experts = global_num_experts // ep_size + local_num_group = top_k // ep_size + + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + + bsz, _ = hidden_states.shape + flatten_topk_ids = topk_ids.view(-1) + sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) + sorted_topk_ids = sorted_topk_ids.to(torch.int32) + sorted_hidden_states = hidden_states.index_select( + 0, sorted_topk_ids // local_num_group) + + experts_id = torch.arange(0, + local_num_experts, + dtype=topk_ids.dtype, + device=topk_ids.device) + num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( + torch.float32).sum(0) + topk_scales = topk_weights.view(-1).index_select( + 0, sorted_topk_ids).unsqueeze(-1) + group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) + + w1 = w1.transpose(1, 2) + gate_up_out = torch_npu.npu_grouped_matmul( + x=[sorted_hidden_states], + weight=[w1], + split_item=2, + group_list_type=0, + group_type=0, + group_list=group_list, + )[0] + + if is_310p(): + gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( + torch.float16) + else: + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + gate_up_out *= topk_scales + + w2 = w2.transpose(1, 2) + down_out_list = torch_npu.npu_grouped_matmul( + x=[gate_up_out], + weight=[w2], + split_item=2, + group_list_type=0, + group_type=0, + group_list=group_list, + )[0] + + unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) + unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) + final_hidden_states = unsorted_hidden_states.reshape( + bsz, top_k // ep_size, -1).sum(1) + + return final_hidden_states + + +def torchair_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + apply_router_weight_on_input: bool = False, + max_num_tokens: Optional[int] = None, +) -> torch.Tensor: + """ + Fused experts with top-k routing. 
+ + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). + w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + top_k: Number of experts to select. + expert_map: Expert mapping of shape (num_experts,). + + Returns: + hidden_states: Hidden states after routing. + """ + """ + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + """ + # if torch.distributed.get_rank() == 0: + # print(w1.shape) + # print(hidden_states.shape) + + original_shape = hidden_states.shape + # assert len(original_shape) == 2 + + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + dtype = hidden_states.dtype + device = hidden_states.device + # assert dtype in [torch.float32, torch.float16, torch.bfloat16 + # ], "Only float32, float16, and bfloat16 are supported" + + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + + if expert_map is not None: + # Generate token indices and flatten + token_indices = (torch.arange(num_tokens, + device=device, + dtype=torch.int64).unsqueeze(1).expand( + -1, top_k).reshape(-1)) + + # Flatten token-to-expert mappings and map to local experts + weights_flat = topk_weights.view(-1) + experts_flat = topk_ids.view(-1) + local_experts_flat = expert_map[experts_flat] + + # Filter valid token-expert pairs + mask = local_experts_flat != -1 + filtered_weights = torch.where( + mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) + filtered_experts = torch.where( + mask, local_experts_flat, + torch.full_like(local_experts_flat, + num_experts)).to(topk_ids.dtype) + + # Sort by local expert IDs + sort_indices = torch.argsort(filtered_experts.view(torch.float32)) + sorted_token_indices = token_indices[sort_indices] + sorted_weights = filtered_weights[sort_indices] + + # Compute token counts with minlength of num_experts + # This is equivalent to but faster than: + # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] + token_counts = torch.zeros(num_experts + 1, + device=device, + dtype=torch.int64) + ones = torch.ones_like(filtered_experts, dtype=torch.int64) + token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) + token_counts = token_counts[:num_experts] + expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64) + + # Rearrange hidden_states + sorted_hidden_states = hidden_states[sorted_token_indices] + else: + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(top_k, -1).permute( + 1, 0).contiguous()) + active_num = max_num_tokens if max_num_tokens is not None else num_tokens + sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + 
expert_idx=topk_ids, + active_num=active_num) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts) + expert_tokens = expert_tokens.to(torch.int64) + + w1 = w1.transpose(1, 2) + gate_up_out_list = torch_npu.npu_grouped_matmul( + x=[sorted_hidden_states], + weight=[w1], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + )[0] + + gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) + + w2 = w2.transpose(1, 2) + down_out_list = torch_npu.npu_grouped_matmul( + x=[gate_up_out], + weight=[w2], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + )[0] + + if expert_map is not None: + weighted_down_out = down_out_list * sorted_weights.unsqueeze(1) + + final_hidden_states = torch.zeros(*original_shape, + device=hidden_states.device, + dtype=dtype) + + # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] + # This created multiple NaN and index_add_ will mix them up which harms accuracy + # remove this mask and filter after it being fixed + num_valid_tokens = mask.sum() + valid_token_mask = torch.arange( + 0, sorted_token_indices.shape[0], + device=device).unsqueeze(1) < num_valid_tokens + valid_output = torch.where( + valid_token_mask, weighted_down_out, + torch.zeros_like(weighted_down_out)).to(dtype) + final_hidden_states.index_add_(0, sorted_token_indices, valid_output) + else: + scales = torch.ones_like( + topk_weights) if apply_router_weight_on_input else topk_weights + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. + final_hidden_states = torch_npu.npu_moe_finalize_routing( + down_out_list, + skip1=None, + skip2=None, + bias=None, + scales=scales, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + + return final_hidden_states + + +def torchair_native_grouped_topk( + topk_weights: torch.Tensor, + num_expert_group: Optional[int], + topk_group: Optional[int], +): + topk_group = 0 if topk_group is None else topk_group + num_expert_group = 0 if num_expert_group is None else num_expert_group + + num_token = topk_weights.shape[0] + grouped_weights = topk_weights.view(num_token, num_expert_group, + -1).max(dim=-1).values + topk_group_indices = torch.topk(grouped_weights.to(torch.float32), + k=topk_group, + dim=-1, + sorted=False)[1] + topk_group_mask = torch.zeros_like(grouped_weights) + topk_group_mask.scatter_(1, topk_group_indices, 1) + topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) + topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) + + return topk_weights + + +def torchair_select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Select top-k experts based on router logits. + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + router_logits: Router logits of shape (num_tokens, num_experts). + top_k: Number of experts to select. + use_grouped_topk: Whether to group experts before selecting top-k. 
+ renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + + Returns: + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + + Raises: + ValueError: If an unsupported scoring function is provided. + """ + + def _renormalize_topk_weights( + topk_weights: torch.Tensor, + renormalize: bool, + ): + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, + keepdim=True) + return topk_weights + + if scoring_func == "softmax": + # NOTE: vLLM use dtype=torch.float here + if not use_grouped_topk and custom_routing_function is None: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( + x=router_logits, finished=None, k=top_k) + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + return topk_weights, topk_ids + + topk_weights = router_logits.softmax(dim=-1) + elif scoring_func == "sigmoid": + topk_weights = router_logits.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_weights = topk_weights + topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) + + # TODO: Change to npu_group_topk when the latest CANN and NNAL is available + # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) + topk_weights = torchair_native_grouped_topk(topk_weights, + num_expert_group, + topk_group) + # TODO bfloat16 is not supported in torch.topk with ge graph. 
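+        # Hence the .to(torch.float32) casts before torch.topk below; they
+        # can be dropped once bfloat16 topk is supported in the GE graph.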
+ if e_score_correction_bias is not None: + topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_weights.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False) + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + return topk_weights, topk_ids + + if custom_routing_function is not None: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts) + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + return topk_weights, topk_ids + + topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) + topk_weights = topk_weights.to(hidden_states.dtype) + + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + topk_weights = _renormalize_topk_weights(topk_weights, renormalize) + + return topk_weights, topk_ids + + +class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): + + def __init__(self, moe: FusedMoEConfig = None): + + super().__init__(moe=moe) + vllm_config = get_current_vllm_config() + + self.global_batch_size = vllm_config.scheduler_config.max_num_seqs + self.max_model_len = vllm_config.model_config.max_model_len + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = None + + def process_weights_after_loading(self, layer): + super(UnquantizedFusedMoEMethod, + self).process_weights_after_loading(layer) + layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( + layer.w13_weight.data), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( + layer.w2_weight.data), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = False, + enable_force_load_balance: bool = False, + shared_experts: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + 
routed_scaling_factor=1, + eps=float(1e-20)) + + topk_weights = topk_weights.to(x.dtype) + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All + + if fused_moe_state == FusedMoEState.MC2: + return torchair_fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + moe_parallel_config=self.moe.moe_parallel_config, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name, + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + mc2_mask=kwargs.get("mc2_mask", None)) + elif fused_moe_state in [ + FusedMoEState.AllGather, FusedMoEState.NaiveMulticast + ]: + return torchair_fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map) + else: + return torchair_fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=get_ep_group()) + + +class TorchairAscendFusedMoE(FusedMoE): + + # The moe_counter parameter is required during the initialization of EPLB + # to identify the current layer index within the MOE model. + moe_counter = -1 + + def __init__( + self, + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + ): + # TODO: This could not initialize FusedMoE baseclass, + # fixme and make __init__() of AscendFusedMoE more clear + super().__init__( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=reduce_results, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + quant_config=quant_config, + tp_size=tp_size, + ep_size=ep_size, + dp_size=dp_size, + prefix=prefix, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + ) + TorchairAscendFusedMoE.moe_counter += 1 + self.moe_instance_id = TorchairAscendFusedMoE.moe_counter + self.prefix = prefix + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + vllm_config = get_current_vllm_config() + + self.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=(tp_size if tp_size is not None else + 
get_tensor_model_parallel_world_size()), + dp_size_=(dp_size + if dp_size is not None else get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config) + + self.top_k = top_k + self.num_experts = num_experts + self.global_num_experts = num_experts + assert intermediate_size % self.tp_size == 0 + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.custom_routing_function = custom_routing_function + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + self.expert_map = None + self.activation = activation + self.log2phy = None + self.global_redundant_expert_num = 0 + + is_deepseek_v3_r1 = self.global_num_experts == 256 + self.rm_router_logits = get_rm_router_logits_state( + self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1) + self.all_reduce_merge = get_all_reduce_merge_state( + self.moe_parallel_config.ep_size, is_deepseek_v3_r1) + + ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + self.expert_map_path = ascend_config.expert_map_path + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.global_num_experts = num_experts + self.global_redundant_expert_num + # static eplb initializing with expert_map_path + if self.expert_map_path and os.path.exists( + self.expert_map_path) and os.access(self.expert_map_path, + os.R_OK): + self.expert_load_balancer = ExpertLoadBalancer( + self.expert_map_path, self.global_num_experts) + self.expert_load_balancer.check_expert_map_tensor() + self.global_redundant_expert_num = ( + self.expert_load_balancer.get_global_redundant_expert_num()) + try: + self.local_num_experts, self.expert_map = ( + self.expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, self.ep_rank)) + self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( + self.moe_instance_id, self.ep_rank).npu() + except Exception as e: + logger.warning( + f"Init expert map of mtp/eagle when using sample.{e}") + self.local_num_experts, self.expert_map = determine_default_expert_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + self.log2phy = determine_default_log2phy_map( + self.global_num_experts, self.ep_size, self.ep_rank).npu() + if self.expert_map is not None and isinstance( + self.expert_map, torch.Tensor): + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) + else: + # init moe. 
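+            # No static expert_map_path was given: fall back to vLLM's default
+            # expert partitioning; if dynamic EPLB is enabled, the placement
+            # and log2phy map are re-derived below with redundant experts.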
+ self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, self.ep_rank, self.global_num_experts) + # dynamic eplb initializing with not expert_map_path + if self.dynamic_eplb: + self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.local_num_experts, self.expert_map = determine_default_expert_map( + self.global_num_experts, self.ep_size, self.ep_rank, + self.global_redundant_expert_num) + self.log2phy = determine_default_log2phy_map( + self.global_num_experts, self.ep_size, self.ep_rank).npu() + if self.expert_map is not None and isinstance( + self.expert_map, torch.Tensor): + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + " number of experts: %s/%s. Experts local to global index map:" + " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map)) + local_num_experts = (torch.sum(self.expert_map != -1) + if self.expert_map is not None else num_experts) + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, + dtype=torch.int64).npu() + + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.multistream_overlap_shared_expert = \ + ascend_config.multistream_overlap_shared_expert and \ + self.torchair_graph_enabled + self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + if self.scoring_func != "softmax" and not self.use_grouped_topk: + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") + self.moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=params_dtype, + ) + if quant_config is None: + self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( + self.moe) + else: + if quant_config.is_layer_skipped_ascend( + prefix, quant_config.packed_modules_mapping): + self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( + self.moe) + else: + self.quant_method = AscendFusedMoEMethod( + quant_config, prefix, quant_config.packed_modules_mapping) + + assert self.quant_method is not None + + self.moe_load = None + local_num_experts = (torch.sum(self.expert_map != -1) + if self.expert_map is not None else num_experts) + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) + + moe_quant_params = { + "num_experts": local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + self.ep_group = get_ep_group() + # NOTE: self.tp_group is not expert_tp_group + self.tp_group = get_tp_group().device_group + self.quant_method.create_weights(layer=self, **moe_quant_params) + + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = 
cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(self.dp_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + get_dp_group().broadcast(buffer[start:end, :], idx) + return buffer + + def forward(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + enable_force_load_balance: bool = False, + top_k: Optional[int] = None, + shared_experts: Optional[Any] = None, + gate=None, + replace_allreduce: bool = False, + _metadata_for_padding: Optional[MetadataForPadding] = None): + + assert self.quant_method is not None + + if top_k: + real_top_k = top_k + else: + real_top_k = self.top_k + + num_tokens, hidden_size = hidden_states.shape + + forward_context = get_forward_context() + fused_moe_state = forward_context.fused_moe_state + mc2_mask = forward_context.mc2_mask + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All + # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. + quantized_x_for_share, dynamic_scale_for_share = None, None + from vllm_npu.torchair.quantization.torchair_w8a8_dynamic import \ + TorchairAscendW8A8DynamicFusedMoEMethod + running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2 + + if self.multistream_overlap_shared_expert: + with super_kernel(self.prefix, + "stream-fusion=1", + enabled=running_in_super_kernel): + if not self.rm_router_logits: + if self.enable_super_kernel: + router_logits, _ = gate(hidden_states.float()) + else: + router_logits, _ = gate(hidden_states) + if hasattr(self.quant_method, "quant_method") and \ + isinstance(self.quant_method.quant_method, + TorchairAscendW8A8DynamicFusedMoEMethod + ) and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( + hidden_states) + + if shared_experts: + if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2: + # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce + shared_hidden_states = shared_experts(hidden_states) + + mc2_mask = forward_context.mc2_mask + + enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill + tp_size = get_tensor_model_parallel_world_size() + if enable_sp: + tp_rank = get_tensor_model_parallel_rank() + mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask + chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0) + mc2_mask = chunk_mc2_mask[tp_rank] + replace_allreduce = True + + if (fused_moe_state not in [ + FusedMoEState.AllGather, FusedMoEState.AllGatherEP, + FusedMoEState.NaiveMulticast + ]): + if tp_size > 1: + tp_rank = get_tensor_model_parallel_rank() + chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) + mc2_mask = chunk_mc2_mask[tp_rank] + if not replace_allreduce: + if fused_moe_state in {FusedMoEState.MC2}: + padding_size = forward_context.padded_num_tokens + else: + # TODO: Determine if we can remove the padding + padding_size = tp_size + if num_tokens < padding_size and not self.enable_shared_expert_dp: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, padding_size - num_tokens)) + router_logits = nn.functional.pad( + router_logits, (0, 0, 0, padding_size - 
num_tokens)) + if tp_size > 1: + tp_rank = get_tensor_model_parallel_rank() + if not self.enable_shared_expert_dp: + chunk_hidden_states = torch.tensor_split(hidden_states, + tp_size, + dim=0) + chunk_router_logits = torch.tensor_split(router_logits, + tp_size, + dim=0) + hidden_states = chunk_hidden_states[tp_rank] + router_logits = chunk_router_logits[tp_rank] + + if self.dp_size > 1: + if fused_moe_state == FusedMoEState.AllGather: + # NOTE: When in torchair graph, it has been padded in model_runner_v1 + if not self.torchair_graph_enabled: + max_tokens_across_dp = forward_context.max_tokens_across_dp + if num_tokens < max_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, max_tokens_across_dp - num_tokens)) + if not self.rm_router_logits: + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, max_tokens_across_dp - num_tokens)) + hidden_states = get_dp_group().all_gather(hidden_states, 0) + if self.rm_router_logits: + router_logits, _ = gate(hidden_states) + else: + router_logits = get_dp_group().all_gather(router_logits, 0) + + elif fused_moe_state == FusedMoEState.NaiveMulticast: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_sp(1) + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + if self.rm_router_logits: + router_logits, _ = gate(hidden_states) + else: + router_logits = self.naive_multicast( + router_logits, cu_tokens_across_dp_cpu) + + # Matrix multiply. + e_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=real_top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + is_prefill=is_prefill, + enable_force_load_balance=enable_force_load_balance, + log2phy=self.log2phy, + global_redundant_expert_num=self.global_redundant_expert_num, + shared_experts=shared_experts if self.torchair_graph_enabled + and self.multistream_overlap_shared_expert and not is_prefill else + None, + mc2_mask=mc2_mask, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + prefix=self.prefix, + running_in_super_kernel=running_in_super_kernel, + ) + + if shared_experts: + if isinstance(e_hidden_states, + tuple) and len(e_hidden_states) == 2: + e_hidden_states, shared_hidden_states = e_hidden_states + + if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 4: + e_hidden_states, shared_hidden_states, group_list_type, expert_tokens = e_hidden_states + if self.dynamic_eplb: + self.moe_load += expert_tokens if group_list_type else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + + if shared_experts is None and isinstance( + e_hidden_states, tuple) and len(e_hidden_states) == 3: + e_hidden_states, group_list_type, expert_tokens = e_hidden_states + if self.dynamic_eplb: + self.moe_load += expert_tokens if group_list_type else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + + if (fused_moe_state not in [ + FusedMoEState.AllGather, FusedMoEState.AllGatherEP, + FusedMoEState.NaiveMulticast + ] and not replace_allreduce and not self.enable_shared_expert_dp): + if tp_size > 1: + if isinstance(e_hidden_states, tuple): + 
e_hidden_states = e_hidden_states[0] + dist.all_gather(list(chunk_hidden_states), e_hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + dispose_tensor(e_hidden_states) + else: + final_hidden_states = e_hidden_states + if num_tokens < padding_size: + final_hidden_states = final_hidden_states[:num_tokens] + elif self.dp_size > 1 and not self.enable_shared_expert_dp: + if fused_moe_state == FusedMoEState.NaiveMulticast: + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + final_hidden_states = get_dp_group().all_reduce( + e_hidden_states) + final_hidden_states = final_hidden_states[start:end, :] + dispose_tensor(e_hidden_states) + elif fused_moe_state == FusedMoEState.AllGather: + final_hidden_states = get_dp_group().reduce_scatter( + e_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + dispose_tensor(e_hidden_states) + else: + final_hidden_states = e_hidden_states + else: + final_hidden_states = e_hidden_states + + if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [ + FusedMoEState.AllGather, FusedMoEState.AllGatherEP, + FusedMoEState.NaiveMulticast + ]: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + if shared_experts: + return final_hidden_states, shared_hidden_states + else: + return final_hidden_states + + def update_expert_map(self, new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.log2phy + + def clear_moe_load(self): + if self.moe_load is not None: + self.moe_load.zero_() + + # ----------------------------------------- TBO-related -------------------------------------------- + + def _forward_ms_fused_moe_comp( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + real_top_k, + enable_force_load_balance: bool = False, + ): + hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=real_top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + is_prefill=is_prefill, + enable_force_load_balance=enable_force_load_balance, + ) + + return hidden_states diff --git a/vllm_npu/torchair/ops/torchair_layernorm.py b/vllm_npu/torchair/ops/torchair_layernorm.py new file mode 100644 index 0000000..e67fc9f --- /dev/null +++ b/vllm_npu/torchair/ops/torchair_layernorm.py @@ -0,0 +1,78 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +from typing import Optional, Tuple, Union + +import torch +from vllm.config import get_current_vllm_config +from vllm.model_executor.layers.layernorm import RMSNorm + +_original_re_init = RMSNorm.__init__ + + +def torchair_rmsnorm_init_( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, +) -> None: + _original_re_init(self, hidden_size, eps, var_hidden_size, has_weight, + dtype) + vllm_config = get_current_vllm_config() + self.bias = None + # quantization with anti_method m4 will generate none-zero norm bias + if vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) + + +def torchair_rmsnorm_forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """AscendRMSNorm forward in torchair mode. + + The key difference from the original implementation is the removal of operators + from the torch.ops.vllm class, as these operators only function in non-torchair + modes. Adding them back would cause the graph compilation to fail. + """ + + import torch_npu + + from vllm_npu.utils import is_310p + if residual is not None: + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) + if self.bias is not None: + x.add_(self.bias) + return x, residual + + x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) + if self.bias is not None: + x.add_(self.bias) + return x diff --git a/vllm_npu/torchair/ops/torchair_rotary_embedding.py b/vllm_npu/torchair/ops/torchair_rotary_embedding.py new file mode 100644 index 0000000..fc854d7 --- /dev/null +++ b/vllm_npu/torchair/ops/torchair_rotary_embedding.py @@ -0,0 +1,365 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import torch_npu +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.utils import enable_custom_op, is_310p + + +def custom_rotary_embedding_enabled(query, neox_style, head_size): + return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( + ) + + +def rope_forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + is_qwen_torchair: Optional[bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if get_ascend_config( + ).torchair_graph_config.enabled and not is_qwen_torchair: + return self.forward_native( + positions, + query, + key, + offsets, + ) + + query_shape, key_shape = query.shape, key.shape + if self.cos_sin_cache.device != query.device: + self.cos_sin_cache = self.cos_sin_cache.to(query.device) + if self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + neox_style = self.is_neox_style + if is_neox_style_override is not None: + neox_style = is_neox_style_override + # adopt custom kernel path for rotary_embedding + if custom_rotary_embedding_enabled(query, neox_style, + self.head_size) and not is_310p(): + query, key = torch.ops._C_ascend.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + neox_style, + ) + return query.view(query_shape), key.view(key_shape) + if offsets is not None: + raise NotImplementedError( + "Batched rotary embedding is currently not supported on NPU.") + else: + # TODO: Remove the contiguous in the future. + query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) + torch_npu._npu_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + neox_style, + ) + return query.view(query_shape), key.view(key_shape) + + +def native_rope_deepseek_forward(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None): + if len(key.shape) == 2: + key = key[:, None, :] + # Note: we implement the non neox_style method with shuffle the last dim and neox style + # calculation method which is also more compute friendly to the ascend machine + # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py + neox_style = True + if self.is_neox_style is False: + b, h_q, d = query.shape + query = query.view(b, h_q, d // 2, 2).transpose(3, + 2).reshape(b, h_q, d) + b, h_k, d = key.shape + key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) + q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, + neox_style) + return q_pe, k_pe + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim(num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. 
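+    # Equivalent closed form:
+    #   d = dim * ln(max_position_embeddings / (num_rotations * 2 * pi))
+    #       / (2 * ln(base))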
+ return (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base))) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +# Find dim range bounds based on rotations +def yarn_find_correction_range(low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + # Note: use torch instead of math to solve MTP compilation error. + low = torch.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = torch.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + # Note: use torch instead of max/min to solve MTP compilation error. + return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) + + +def yarn_linear_ramp_mask(min_value, max_value, dim): + # Note: The if conditional branch is not used here + # to solve MTP compilation error. + max_value += (min_value == max_value).float() * 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - + min_value) / (max_value - min_value) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids] + sin = sin[position_ids] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + + if len(q.shape) == 3: + q = q[:, :, None, :] + if len(k.shape) == 2: + k = k[:, None, None, :] + elif len(k.shape) == 3: + k = k[:, :, None, :] + + b, h_q, s, d = q.shape + q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d) + + b, h_k, s, d = k.shape + k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = q_embed.view(b, h_q, d) + k_embed = k_embed.view(b, h_k, d) + + return q_embed, k_embed + + +def _set_cos_sin_cache(self, max_seq_len, device, dtype): + dim = self.rotary_dim + + freq_extra = 1.0 / (self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base**( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale + sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale + cos_cached = cos_cached.to(dtype) + sin_cached = sin_cached.to(dtype) + cache = torch.cat([freqs.cos() * self.mscale, + freqs.sin() * self.mscale], + dim=-1).to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.register_buffer("cos_cached", cos_cached, persistent=False) + self.register_buffer("sin_cached", sin_cached, persistent=False) + + +def __set_cos_sin_cache(self, seq_len, device, dtype): + inv_freq = 1.0 / (self.base**(torch.arange( + 0, self.rotary_dim, 2, device=device, dtype=torch.float32) * + (1 / self.rotary_dim))) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False) + self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False) + self.embed = F.embedding + + +_original_re_init = RotaryEmbedding.__init__ + + +def qwen_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, +) -> None: + _original_re_init(self, head_size, rotary_dim, max_position_embeddings, + base, is_neox_style, dtype) + if get_ascend_config().torchair_graph_config.enabled: + __set_cos_sin_cache(self, + seq_len=max_position_embeddings, + device="npu", + dtype=dtype) + + +def rope_forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + max_seq_len: Optional[int] = None, + is_prefill: Optional[bool] = True, + is_qwen_torchair: Optional[bool] = False, +): + if get_ascend_config().torchair_graph_config.enabled \ + and is_qwen_torchair and not is_prefill: + if max_seq_len is not None and torch.gt(max_seq_len, + 
self.max_position_embeddings): + __set_cos_sin_cache(self, + seq_len=max_seq_len, + device=query.device, + dtype=torch.float32) + + # bsnd/bnsd + if positions is not None: + cos = self.embed(positions, self.cos) + sin = self.embed(positions, self.sin) + self.cos_embed = cos + self.sin_embed = sin + else: + cos = self.cos_embed + sin = self.sin_embed + + query = query.view(*query.shape[:-1], -1, self.head_size).contiguous() + key = key.view(*key.shape[:-1], -1, self.head_size).contiguous() + + cos = cos.unsqueeze(-2).unsqueeze(-2) + sin = sin.unsqueeze(-2).unsqueeze(-2) + + query = query.unsqueeze(1) + key = key.unsqueeze(1) + + q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb( + query, key, cos, sin) + return q_embed.flatten(-2), k_embed.flatten(-2) + else: + return rope_forward_oot(self, positions, query, key, offsets, + is_neox_style_override, + is_qwen_torchair) # type: ignore + + +def deepseek_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, +) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super(DeepseekScalingRotaryEmbedding, + self).__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) + _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu") diff --git a/vllm_npu/torchair/ops/torchair_vocab_parallel_embedding.py b/vllm_npu/torchair/ops/torchair_vocab_parallel_embedding.py new file mode 100644 index 0000000..f83f2bc --- /dev/null +++ b/vllm_npu/torchair/ops/torchair_vocab_parallel_embedding.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from vllm.distributed import tensor_model_parallel_all_reduce + + +def vocab_embedding_forward(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = self._get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. 
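
# Aside: a minimal CPU sketch of the mask-and-all-reduce pattern this forward
# implements, simulated for two vocab shards without torch.distributed. Token
# ids owned by another shard are looked up at a dummy local index, their rows
# are zeroed, and summing the per-shard results (the stand-in for
# tensor_model_parallel_all_reduce) recovers the full embedding. The shard
# bounds and sizes are made-up illustration values, not the project's defaults.
import torch
import torch.nn.functional as F

def shard_embed(input_ids, shard_weight, vocab_start, vocab_end):
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = torch.where(mask, torch.zeros_like(input_ids),
                            input_ids - vocab_start)
    out = F.embedding(local_ids, shard_weight)
    return out.masked_fill(mask.unsqueeze(-1), 0)

full_weight = torch.randn(8, 4)                        # vocab_size=8, hidden=4
token_ids = torch.tensor([1, 5, 7])
out0 = shard_embed(token_ids, full_weight[0:4], 0, 4)  # shard 0 owns ids [0, 4)
out1 = shard_embed(token_ids, full_weight[4:8], 4, 8)  # shard 1 owns ids [4, 8)
assert torch.equal(out0 + out1, F.embedding(token_ids, full_weight))
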
+ output_parallel = self.quant_method.embedding(self, masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output diff --git a/vllm_npu/torchair/quantization/__init__.py b/vllm_npu/torchair/quantization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_npu/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_npu/torchair/quantization/torchair_w4a8_dynamic.py new file mode 100644 index 0000000..d8111b4 --- /dev/null +++ b/vllm_npu/torchair/quantization/torchair_w4a8_dynamic.py @@ -0,0 +1,486 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional + +import numpy as np +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.ascend_forward_context import FusedMoEState +from vllm_npu.distributed.parallel_state import get_mc2_group +from vllm_npu.torchair.quantization.torchair_w8a8_dynamic import ( + torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) +from vllm_npu.torchair.utils import npu_stream_switch, npu_wait_tensor + + +class TorchairAscendW4A8DynamicLinearMethod: + """Linear method for Ascend W4A8_DYNAMIC + """ + + def __init__(self): + self.transpose_weight = True + + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + self.new_quant_version = quant_version == "1.0.0" + + from vllm.distributed import get_tensor_model_parallel_world_size + self.tp_size = get_tensor_model_parallel_world_size() + + def get_weight(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + + if self.new_quant_version: + pack_factor = 2 + actual_output_size = output_size // pack_factor + params_dict["weight"] = torch.empty(actual_output_size, + input_size, + dtype=torch.int8) + params_dict["_packed_dim"] = 0 + params_dict["_packed_factor"] = pack_factor + else: + params_dict["weight"] = torch.empty(output_size, + input_size, + dtype=torch.int8) + + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param(output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = 
torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_scale_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + params_dict["weight_offset_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + + if self.new_quant_version: + scale_bias_dim = 16 if layer_type == "row" else 1 + params_dict["scale_bias"] = torch.empty(output_size, + scale_bias_dim, + dtype=torch.float32) + return params_dict + + @staticmethod + def process_scale_second(weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False): + k, n = weight.shape + group_num, n_scale = per_group_scale.shape + + if is_new_quant: + n = n * 2 + + bias = None + if not is_new_quant: + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale.npu(), bias + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = None, + ) -> torch.Tensor: + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = layer.weight_scale.data.flatten().to( + torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_scale_second.data, scale_bias = self.process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, + ) + + if self.new_quant_version: + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + else: + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + + if self.new_quant_version: + assert layer.weight.data.shape[-1] % 4 == 0, \ + f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view( + torch.int32).contiguous() + else: + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) + + +class TorchairAscendW4A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W4A8_DYNAMIC. 
+ """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process + self.is_per_channel_weight = self.group_size == 0 + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + # NOTE: new quantize weights: 2 int4 pack into int8 + self.new_quant_version = quant_version == "1.0.0" + self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size + if self.new_quant_version and self.tp_size > 16: + raise ValueError( + "The current weight does not support moe part tp>16.") + + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + def get_weight(self, num_experts: int, + intermediate_size_per_partition: int, hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + if self.new_quant_version: + w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_sizes // 2 + else: + w13_output_size = 2 * intermediate_size_per_partition + w2_output_size = hidden_sizes + + param_dict["w13_weight"] = torch.empty(num_experts, + w13_output_size, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + w2_output_size, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + def get_dynamic_quant_param(self, num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float32) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=torch.float32) + + if not self.is_per_channel_weight: + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + + if self.new_quant_version: + param_dict["w13_scale_bias"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + param_dict["w2_scale_bias"] 
= torch.empty(num_experts, + hidden_sizes, + 16 // self.tp_size, + dtype=torch.float32) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + + fused_moe_state = get_forward_context().fused_moe_state + shared_gate_up, shared_dequant_scale = None, None + if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, router_logits) + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + topk_weights = topk_weights.to(x.dtype) + if fused_moe_state == FusedMoEState.MC2: + return torchair_fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + quantized_x_for_share=shared_gate_up, + dynamic_scale_for_share=shared_dequant_scale, + mc2_mask=kwargs.get("mc2_mask", None), + dynamic_eplb=self.dynamic_eplb) + else: + # The current implementation of deepseek moe splits hidden_states + # according to tp_size before they are feed into layers module. + # Therefore, all2all is needed no matter how dp/tp is set so as to + # dispatch/combine tokens. 
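
# Aside: a small CPU-only sketch of the bookkeeping behind the all2all dispatch
# taken below. Each rank counts how many (token, expert) rows it must ship to
# every other EP rank (contiguous expert blocks are assumed to map to ranks);
# exchanging these counts gives the split sizes for the payload all-to-all, as
# done with global_expert_tokens inside the helper. The sizes here are
# illustrative, not the deployment's real EP layout.
import torch

world_size, experts_per_rank, top_k, num_tokens = 4, 8, 8, 16
global_num_experts = world_size * experts_per_rank
topk_ids = torch.randint(0, global_num_experts, (num_tokens, top_k))

tokens_per_expert = torch.bincount(topk_ids.reshape(-1),
                                   minlength=global_num_experts)
send_counts = tokens_per_expert.view(world_size, experts_per_rank).sum(dim=1)
# send_counts[r] rows go to rank r; dist.all_to_all_single on these counts
# yields the matching receive sizes used to split the token payload.
assert send_counts.sum().item() == num_tokens * top_k
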
+ return torchair_fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=self.ep_group, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + ) + + def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + scale = scale.transpose(1, 2).contiguous() + if self.is_per_channel_weight: + scale_np = scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype( + np.int64)).npu() + return scale_uint64_tensor, None + per_group_scale = per_group_scale.transpose(1, 2).contiguous() + group_num, k, n = weight.shape + # the weight of the new version is reduced by half by pack n, so it needs to be restored + if self.new_quant_version: + n = n * 2 + per_group_scale = per_group_scale.reshape(group_num, -1, n) + group_num, quantgroup_num, n = per_group_scale.shape + bias = None + if not self.new_quant_version: + weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ + per_group_scale.reshape([group_num, quantgroup_num, 1, n]) + weight_high = weight_high.reshape([group_num, k, n]) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) + scale_fp32 = (scale * per_group_scale).to(torch.float16).to( + torch.float32) + scale_fp32_np = scale_fp32.cpu().numpy() + scale_fp32_np.dtype = np.uint32 + sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), + dtype=np.uint32) + + sscale_uint64[..., ::2] = scale_fp32_np + + sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), + dtype=np.int64).copy() + sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape( + group_num, quantgroup_num, n) + sscale_uint64_tensor = sscale_uint64_tensor.npu() + return sscale_uint64_tensor, bias + + def update_bias(self, layer, w13_bias, w2_bias): + if self.new_quant_version: + layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + else: + w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + + def pack_to_int32(self, weight: torch.Tensor): + if self.new_quant_version: + # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 + assert weight.shape[ + -1] % 4 == 0, "the last dim of weight needs to be divided by 4" + return weight.view(torch.int32).contiguous() + else: + return torch_npu.npu_quantize(weight.to(torch.float32), + torch.tensor([1.]).npu(), None, + torch.quint4x2, -1, False) + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( + layer, "w13_weight_scale_second") else None + w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( + layer, "w2_weight_scale_second") else None + layer.w13_weight_scale.data, w13_bias = self.process_scale( + layer.w13_weight, 
layer.w13_weight_scale.data, + w13_weight_scale_second) + layer.w2_weight_scale.data, w2_bias = self.process_scale( + layer.w2_weight, layer.w2_weight_scale.data, + w2_weight_scale_second) + if hasattr(layer, "w13_weight_scale_second"): + # scale_second is no longer used, release this part of the memory + del layer.w13_weight_scale_second + del layer.w2_weight_scale_second + del layer.w13_weight_offset_second + del layer.w2_weight_offset_second + + self.update_bias(layer, w13_bias, w2_bias) + + layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) + layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) diff --git a/vllm_npu/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_npu/torchair/quantization/torchair_w8a8_dynamic.py new file mode 100644 index 0000000..6baa92b --- /dev/null +++ b/vllm_npu/torchair/quantization/torchair_w8a8_dynamic.py @@ -0,0 +1,1064 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch_npu +from vllm.distributed import GroupCoordinator, get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.ascend_forward_context import FusedMoEState +from vllm_npu.distributed.parallel_state import get_mc2_group +from vllm_npu.torchair.utils import (npu_stream_switch, npu_wait_tensor, + super_kernel) +from vllm_npu.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, + dispose_tensor, get_ascend_soc_version, + is_enable_nz, + is_hierarchical_communication_enabled) + + +def torchair_apply_mlp_decode(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + w2_scale: weights2 scale with shape (num_experts, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + Returns: + hidden_states: output hidden states after MLP. 
+ """ + + if dynamic_scale is None: + unquantized_hidden_states = hidden_states + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used. + dispose_tensor(unquantized_hidden_states) + else: + pertoken_scale = dynamic_scale + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + return hidden_states + + +def torchair_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + + Args: + hidden_states: input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + w2_scale: weights2 scale with shape (num_experts, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + + Returns: + hidden_states: output hidden states after MLP. + """ + + if dynamic_scale is None: + unquantized_hidden_states = hidden_states + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used. 
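
# Aside: a plain-PyTorch reference for the per-token dynamic quantization that
# npu_dynamic_quant performs a few lines above (symmetric int8 with one scale
# per token). The fused op's exact rounding/clamping is not restated; this is
# only a generic sketch of what `hidden_states` / `pertoken_scale` hold.
import torch

def dynamic_quant_reference(x: torch.Tensor):
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp_(-128, 127).to(torch.int8)
    return q, scale.squeeze(-1)

x = torch.randn(4, 16)
q, per_token_scale = dynamic_quant_reference(x)
x_hat = q.to(torch.float32) * per_token_scale.unsqueeze(-1)  # dequantize
assert (x - x_hat).abs().max() <= per_token_scale.max()      # within one step
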
+ dispose_tensor(unquantized_hidden_states) + else: + pertoken_scale = dynamic_scale + + bias1, bias2 = None, None + _output_dtype = w2_scale.dtype + + if w1_scale_bias is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + scale=[w1_scale], + bias=bias1, + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + bias=bias2, + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + + return hidden_states + + +def torchair_fused_experts_with_mc2( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + moe_all_to_all_group_name: str = "", + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + is_torchair: bool = False, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + shared_gate_up: Optional[Any] = None, + shared_dequant_scale: Optional[Any] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + dynamic_eplb: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert mc2_mask is not None + if log2phy is not None: + topk_ids = log2phy[topk_ids] + + quant_mode = 2 + ep_group = get_mc2_group() + ep_rank_id = ep_group.rank_in_group + ep_world_size = ep_group.world_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and + # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly + # improve communication performance. 
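
# Aside: the A2 tuning mentioned in the NOTE above is driven purely by
# environment variables; they would typically be exported before the serving
# process starts, e.g. `export HCCL_INTRA_PCIE_ENABLE=1` and
# `export HCCL_INTRA_ROCE_ENABLE=0`, or set from Python before any HCCL
# communicator is created (assumption: HCCL reads them once at init time).
import os
os.environ.setdefault("HCCL_INTRA_PCIE_ENABLE", "1")
os.environ.setdefault("HCCL_INTRA_ROCE_ENABLE", "0")
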
+ need_expert_scale = is_hierarchical_communication_enabled() + + enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") + + if (expert_map is not None): + moe_expert_num = len(expert_map) + else: + moe_expert_num = global_redundant_expert_num + # hidden_states = hidden_states.bfloat16() + kwargs_mc2 = { + "x": hidden_states, + "expert_ids": topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + + stage1_kwargs = { + "scales": None, + "quant_mode": quant_mode, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and enable_dispatch_v2: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) + if need_expert_scale: + stage1_kwargs.update({ + "expert_scales": topk_weights.to(torch.float32), + }) + kwargs_mc2.update(stage1_kwargs) + + output = torch_npu.npu_moe_distribute_dispatch_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( + **kwargs_mc2) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ + ep_recv_counts, _, expand_scales = output[0:7] + + if shared_experts is not None: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(shared_gate_up, expand_x) + shared_act_out = shared_experts.act_fn( + (shared_gate_up, shared_dequant_scale)) + shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] + + # `expand_x` will be disposed in the `apply_mlp` function + if w1_scale_bias is None: + down_out_list = torchair_apply_mlp_decode(expand_x, + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale) + else: + # w4a8 scene, cannot use apply_mlp_decode because the operator is not supported + down_out_list = torchair_apply_mlp(expand_x, + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) + + # moeCombine + kwargs_mc2 = { + "expand_x": down_out_list, + "expert_ids": topk_ids, + "expert_scales": topk_weights.to(torch.float32), + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + tp_recv_counts = torch.empty(1, + dtype=torch.int32, + device=hidden_states.device) + stage3_kwargs = { + "ep_send_counts": ep_recv_counts, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + "expand_scales": expand_scales, + } + if enable_dispatch_v2: + stage3_kwargs.update({ + "assist_info_for_combine": + assist_info_for_combine, + }) + else: + stage3_kwargs.update({ + "expand_idx": assist_info_for_combine, + }) + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and enable_dispatch_v2: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) + kwargs_mc2.update(stage3_kwargs) + + hidden_states = torch_npu.npu_moe_distribute_combine_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( + **kwargs_mc2) + + if shared_experts is None: + if dynamic_eplb: + return (hidden_states, 1, expert_token_nums) + return hidden_states + else: + with npu_stream_switch("moe_secondary", 0): + 
npu_wait_tensor(shared_act, down_out_list) + shared_output, _ = shared_experts.down_proj( + (shared_act, swiglu_out_scale)) + if dynamic_eplb: + return (hidden_states, shared_output, 1, expert_token_nums) + return (hidden_states, shared_output) + + +def torchair_init_routing_quant(hidden_states, top_k, topk_ids, + global_num_experts): + num_tokens, _ = hidden_states.shape + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=hidden_states.device).view( + top_k, -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute( + 1, 0).contiguous().view(-1)) + global_expert_tokens = torch.bincount(expanded_expert_idx, + minlength=global_num_experts) + global_expert_tokens = global_expert_tokens.to(torch.int32) + quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states) + return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales + + +# currently expert parallelism implemented with all2all +# is under-optimized. +def torchair_fused_experts_with_all2all( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, +): + if log2phy is not None: + topk_ids = log2phy[topk_ids] + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + num_experts = w1.shape[0] + + if expert_map is not None: + assert ep_group is not None, "ep_group must be provided when expert_map is given" + global_num_experts = len(expert_map) + if hasattr(torch_npu, "npu_moe_init_routing_quant"): + quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant( + hidden_states, + expert_idx=topk_ids.to(torch.int32), + active_num=0, + expert_capacity=0, + expert_num=global_num_experts, + drop_pad_mode=0, + expert_tokens_num_mode=2, + expert_tokens_before_capacity_flag=False, + quant_mode=1, + ) + else: + quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = torchair_init_routing_quant( + hidden_states, top_k, topk_ids, global_num_experts) + + gather_sizes = global_expert_tokens.new_empty( + global_expert_tokens.shape[0]) + dist.all_to_all_single(gather_sizes, + global_expert_tokens, + group=ep_group.device_group) + token_counts_combined = torch.stack( + [gather_sizes, global_expert_tokens], dim=0) + token_counts_combined = token_counts_combined.view( + 2, ep_group.world_size, -1).sum(dim=2) + token_counts_combined_cpu = token_counts_combined.to( + torch.device("cpu"), non_blocking=False).numpy() + all_tokens = gather_sizes.sum() + + gathered_tokens = quantized_tokens.new_empty(all_tokens.item(), + quantized_tokens.shape[1]) + dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0]) + gather_size_list = token_counts_combined_cpu[1] + scatter_size_list = token_counts_combined_cpu[0] + + dist.all_to_all_single(gathered_tokens, + quantized_tokens, + scatter_size_list, + gather_size_list, + 
group=ep_group.device_group) + dist.all_to_all_single(dynamic_scale, + token_scales, + scatter_size_list, + gather_size_list, + group=ep_group.device_group) + + hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing( + gathered_tokens, + gather_sizes.view(ep_group.world_size, -1), + per_token_scales=dynamic_scale) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 1 + else: + row_idx_len = num_tokens * top_k + row_idx = torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous() + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 0 + dynamic_scale = None + + # `hidden_states` will be disposed in the `apply_mlp` function + hidden_states = torchair_apply_mlp( + hidden_states, + w1, + w1_scale, #17 + w2, + w2_scale, + expert_tokens, #16 + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) + + if expert_map is not None: + reordered_outputs = torch.index_select( + hidden_states, + dim=0, + # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU + index=inverse_indices.to(torch.float32).argsort().to(torch.int32)) + + hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape) + dist.all_to_all_single(hidden_states, + reordered_outputs, + gather_size_list, + scatter_size_list, + group=ep_group.device_group) + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=None, + drop_pad_mode=2) + else: + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. 
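
# Aside: a generic CPU sketch of the weighted combine that the
# npu_moe_finalize_routing call below is responsible for: each token's top_k
# expert outputs are scaled by topk_weights and summed back into one row per
# token. The fused op additionally undoes the permutation produced by
# npu_moe_init_routing (via expanded_src_to_dst_row); this sketch assumes the
# rows are already grouped token-major to keep it short.
import torch

num_tokens, top_k, hidden = 4, 2, 8
expert_out = torch.randn(num_tokens * top_k, hidden)  # one row per (token, expert)
topk_weights = torch.rand(num_tokens, top_k)
row_to_token = torch.arange(num_tokens).repeat_interleave(top_k)
combined = torch.zeros(num_tokens, hidden)
combined.index_add_(0, row_to_token, expert_out * topk_weights.reshape(-1, 1))
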
+ final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + +def torchair_fused_experts_with_allgather(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + num_tokens = hidden_states.shape[0] + batch_size, hidden_size = hidden_states.shape + topk_weights = topk_weights.to(hidden_states.dtype) + + ep_group = get_ep_group().device_group + ep_rank = torch.distributed.get_rank(group=ep_group) + ep_size = torch.distributed.get_world_size(ep_group) + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_size + + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + + hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + scale=pertoken_scale, + offset=None, + active_num=num_tokens * top_k, + expert_num=global_num_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[ + ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts + ], + quant_mode=-1, + row_idx_type=1) + group_list_type = 1 + + sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0, + expanded_x_idx) + row_index = expanded_x_idx // topk_ids.shape[-1] + row_index = row_index.to(torch.int64) + share_input = torch.zeros((batch_size, hidden_size), + dtype=torch.bfloat16, + device="npu") + + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=expert_tokens, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale.to(torch.float32), + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=expert_tokens, + activate_left=True, + quant_mode=1, + ) + + final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing( + hidden_states, + w2, + scale=w2_scale.to(torch.float32), + bias=None, + pertoken_scale=pertoken_scale.view(-1), + group_list=expert_tokens, + shared_input=share_input, + logit=sorted_topk_weight.to(torch.float32), + row_index=row_index, + output_bs=batch_size).to(torch.bfloat16) + + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + + return final_hidden_states + + +def torchair_fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + num_experts = w1.shape[0] + dtype = hidden_states.dtype + device = hidden_states.device + + if expert_map is not None: + # Generate token indices and flatten + 
token_indices = (torch.arange(num_tokens, + device=device, + dtype=torch.int64).unsqueeze(1).expand( + -1, top_k).reshape(-1)) + + # Flatten token-to-expert mappings and map to local experts + weights_flat = topk_weights.view(-1) + experts_flat = topk_ids.view(-1) + local_experts_flat = expert_map[experts_flat] + + # Filter valid token-expert pairs + mask = local_experts_flat != -1 + filtered_weights = torch.where( + mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) + filtered_experts = torch.where( + mask, local_experts_flat, + torch.full_like(local_experts_flat, + num_experts)).to(topk_ids.dtype) + + # Sort by local expert IDs + sort_indices = torch.argsort(filtered_experts) + sorted_token_indices = token_indices[sort_indices] + sorted_weights = filtered_weights[sort_indices] + + # Compute token counts with minlength of num_experts + # This is equivalent to but faster than: + # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] + token_counts = torch.zeros(num_experts + 1, + device=device, + dtype=torch.int64) + ones = torch.ones_like(filtered_experts, dtype=torch.int64) + token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) + expert_tokens = token_counts[:num_experts] + # Rearrange hidden_states + hidden_states = hidden_states[sorted_token_indices] + group_list_type = 1 + else: + row_idx_len = num_tokens * top_k + row_idx = torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous() + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 0 + + # `hidden_states` will be disposed in the `apply_mlp` function + hidden_states = torchair_apply_mlp(hidden_states, + w1, + w1_scale, + w2, + w2_scale, + expert_tokens, + group_list_type=group_list_type) + + if expert_map is not None: + hidden_states.mul_(sorted_weights.unsqueeze(1)) + final_hidden_states = torch.zeros(*original_shape, + device=device, + dtype=dtype) + + num_valid_tokens = mask.sum() + valid_token_mask = torch.arange( + 0, sorted_token_indices.shape[0], + device=device).unsqueeze(1) < num_valid_tokens + hidden_states = hidden_states.masked_fill_(~valid_token_mask, + 0).to(dtype) + final_hidden_states.index_add_(0, sorted_token_indices, hidden_states) + else: + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + +class TorchairAscendW8A8DynamicLinearMethod: + """Linear method for Ascend W8A8_DYNAMIC. 
+ """ + + def __init__(self): + self.transpose_weight = True + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + return params_dict + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + return {} + + @staticmethod + def apply( + layer: torch.nn.Module, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + config = getattr(layer, "_ascend_quant_config", {}) + if not isinstance(x, tuple): + output_dtype = config.get("output_dtype", x.dtype) + quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + else: + assert "output_dtype" in config.keys(), ( + f"DynamicLinearMethod needs explicitly specified `output_dtype`" + f"for pre-quantized input, got config [{config}]") + output_dtype = config["output_dtype"] + quantized_x, dynamic_scale = x + pertoken_scale = (dynamic_scale + if config.get("pertoken_scale", True) else None) + + output = torch_npu.npu_quant_matmul( + quantized_x, + layer.weight, + layer.weight_scale, + pertoken_scale=pertoken_scale, + bias=bias, + output_dtype=output_dtype, + ) + return ((output, dynamic_scale) + if config.get("return_scale", False) else output) + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + # cast quantized weight tensors in NZ format (29) for higher inference speed + if is_enable_nz(): + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, 29) + layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + + +class TorchairAscendW8A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W8A8_DYNAMIC. 
+ """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + @staticmethod + def get_weight(num_experts: int, intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty(num_experts, + 2 * + intermediate_size_per_partition, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + hidden_sizes, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + prefix: str = "", + running_in_super_kernel: bool = False, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + + fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All + shared_gate_up, shared_dequant_scale = None, None + + with super_kernel(prefix, + "stream-fusion=1", + enabled=running_in_super_kernel): + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: 
sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + + if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, router_logits) + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like( + topk_ids, 0, + global_num_experts - global_redundant_expert_num) + topk_weights = topk_weights.to(x.dtype) + + if fused_moe_state == FusedMoEState.AllGatherEP: + return torchair_fused_experts_with_allgather( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map) + elif fused_moe_state == FusedMoEState.MC2: + with super_kernel(prefix, + "stream-fusion=1", + enabled=running_in_super_kernel): + return torchair_fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_fp32, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + mc2_mask=kwargs.get("mc2_mask", None), + shared_gate_up=shared_gate_up, + shared_dequant_scale=shared_dequant_scale, + dynamic_eplb=self.dynamic_eplb) + elif fused_moe_state in [ + FusedMoEState.AllGather, FusedMoEState.NaiveMulticast + ]: + return torchair_fused_experts(hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map) + else: + # The current implementation of deepseek moe splits hidden_states + # according to tp_size before they are feed into layers module. + # Therefore, all2all is needed no matter how dp/tp is set so as to + # dispatch/combine tokens. 
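
# Aside: a plain-PyTorch reference for the grouped top-k routing that the
# npu_moe_gating_top_k call earlier in this method fuses, reconstructed from
# its inline comments (sigmoid scoring, groups ranked by the sum of their top-2
# experts, topk_group groups kept, then top_k experts overall). The parameter
# semantics are an interpretation, not a restatement of the op's contract.
import torch

def grouped_topk_reference(router_logits, correction_bias, top_k=8,
                           num_expert_group=8, topk_group=4):
    scores = router_logits.sigmoid()
    biased = scores + correction_bias           # bias is used for selection only
    n_tokens, n_experts = biased.shape
    grouped = biased.view(n_tokens, num_expert_group, -1)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    keep = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(
        n_tokens, n_experts)
    topk_ids = biased.masked_fill(expert_mask == 0,
                                  float("-inf")).topk(top_k, dim=-1).indices
    return scores.gather(1, topk_ids), topk_ids  # weights from unbiased scores

weights, ids = grouped_topk_reference(torch.randn(4, 64), torch.zeros(64))
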
+ return torchair_fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=self.ep_group, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + ) + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + if is_enable_nz(): + torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ) + torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) + layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( + layer.w13_weight_scale.data.shape[0], -1) + layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( + torch.float32) + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( + layer.w13_weight_offset.data.shape[0], -1) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( + layer.w2_weight_scale.data.shape[0], -1) + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( + layer.w2_weight_offset.data.shape[0], -1) diff --git a/vllm_npu/torchair/torchair_attention.py b/vllm_npu/torchair/torchair_attention.py new file mode 100644 index 0000000..3e469f2 --- /dev/null +++ b/vllm_npu/torchair/torchair_attention.py @@ -0,0 +1,463 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn +import torch_npu +from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, + AttentionType) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.utils import cdiv + +from vllm_npu.attention.attention_v1 import (AscendAttentionBackend, + AscendAttentionMetadataBuilder, + AscendAttentionState, + AscendMetadata) +from vllm_npu.attention.utils import AscendCommonAttentionMetadata +from vllm_npu.torchair.utils import TorchairCommonAttentionMetadata +from vllm_npu.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, + nd_to_nz_2d) + + +class AscendAttentionTorchairBackend(AscendAttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_TORCHAIR" + + @staticmethod + def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]: + return AscendAttentionTorchairBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AscendTorchairMetadata"]: + return AscendTorchairMetadata + + @staticmethod + def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]: + return AscendAttentionTorchairMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads * head_size) + + @staticmethod + def get_bsh_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size, num_kv_heads * head_size) + + +@dataclass +class AscendDecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendTorchairMetadata(AscendMetadata): + + decode: Optional[AscendDecodeMetadata] = None + + +class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder): + + def __init__( + self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.max_num_blocks_per_req = cdiv( + self.model_config.max_model_len, + self.vllm_config.cache_config.block_size) + self.max_blocks = (self.model_config.max_model_len + + self.vllm_config.cache_config.block_size - + 1) // self.vllm_config.cache_config.block_size + + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] + + def build_torchair_graph_dummy( + self, common_attn_metadata: TorchairCommonAttentionMetadata + ) -> AscendTorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), + 
dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) + input_positions = torch.zeros(num_reqs, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_reqs, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) + + decode_metadata = AscendDecodeMetadata(input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), + max_seq_lens=1) + + attn_metadata = AscendTorchairMetadata( + num_actual_tokens=common_attn_metadata.num_actual_tokens, + block_tables=block_table, + query_lens=0, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + attn_state=AscendAttentionState.DecodeOnly, + decode=decode_metadata) + return attn_metadata + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: Optional[nn.Module] = None, + ): + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + + block_table = common_attn_metadata.block_table_tensor + block_table[:num_reqs, :self.max_num_blocks_per_req] = ( + block_table[:num_reqs]) + + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] + attn_mask = common_attn_metadata.attn_mask + + attn_state = common_attn_metadata.attn_state + if is_310p() and attn_state == AscendAttentionState.PrefillNoCache: + mask_nz = nd_to_nz_2d(attn_mask) + attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_reqs + + 1] + query_start_loc = query_start_loc_cpu.to(self.device, + non_blocking=True) + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size + use_torchair_graph = graph_pad_size > -1 + if common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + ]: + max_seq_lens = seq_lens.max().item() + num_seqs = len(seq_lens) + if use_torchair_graph and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + ]: + num_reqs_pad_size = 0 + num_token_pad_size = 0 + if graph_pad_size != 0: + pad_value = 0 + num_token_pad_size = graph_pad_size - num_actual_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + pad_value = 1 + padded_seq_lens = seq_lens.tolist() + [pad_value + ] * num_reqs_pad_size + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)) + padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, padding]) + block_table_padding = torch.zeros( + (num_reqs_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_seqs + num_reqs_pad_size, block_table) + padding_0 = torch.zeros(num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat([input_positions, padding_0]) + + decode_metadata = AscendDecodeMetadata( 
+ input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), + max_seq_lens=max_seq_lens, + attn_mask=None) + + attn_metadata = AscendTorchairMetadata( + decode=decode_metadata, + num_actual_tokens=num_actual_tokens, + block_tables=block_table, + query_start_loc=query_start_loc, + query_lens=query_lens, + seq_lens=seq_lens, + max_query_len=common_attn_metadata.max_query_len, + slot_mapping=slot_mapping, + attn_mask=attn_mask, + attn_state=attn_state, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) + return attn_metadata + + +class AscendAttentionTorchairBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.hidden_size = self.num_heads * self.head_size + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = sliding_window + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, + dtype=torch.float32, + device="npu") + self.alibi_slopes = alibi_slopes + self.attn_type = attn_type + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.key_cache = None + self.value_cache = None + self.scale_tensor = torch.zeros((), device='npu', dtype=torch.int32) + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AscendTorchairMetadata, + output: Optional[torch.Tensor] = None, + trace_flag: bool = False, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache: shape = [2, num_blocks, block_size, + num_kv_heads, head_size] + key_cache = [num_blocks, block_size, + num_kv_heads, head_size] + value_cache = [num_blocks, block_size, + num_kv_heads, head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [batch_size * seq_len, num_heads, head_size] + """ + num_tokens = query.shape[0] + use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0 + and kv_cache[0].numel() > 0 + and kv_cache[0].dtype == torch.int8) + if output is None: + output = torch.empty(num_tokens, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + + if hasattr(layer, 'quant_method') and use_kv_cache_quant: + output = layer.quant_method.apply(layer, query, key, value, + kv_cache, attn_metadata, + self.attn_type, self.scale, + output) + return output.view(num_tokens, self.hidden_size) + + if attn_metadata is None: + return output.view(num_tokens, self.hidden_size).fill_(0) + + output = output.view(-1, self.num_heads, self.head_size) + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + attn_type = self.attn_type + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "AscendAttentionTorchairBackendImpl") + + if kv_cache is not None and kv_cache[0].numel() > 0: + key_cache, value_cache = kv_cache[0], kv_cache[1] + slots = attn_metadata.slot_mapping + + block_size = self.scale_tensor + key_cache.shape[1] + slots_indices = slots.reshape(-1, 1) + block_indices = slots_indices // block_size + slots_indices = slots_indices % block_size + indices = torch.cat((block_indices, slots_indices), dim=1) + torch_npu.npu_scatter_nd_update_(key_cache, indices, key) + torch_npu.npu_scatter_nd_update_(value_cache, indices, value) + if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + self.key_cache = key_cache + self.value_cache = value_cache + + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + + # View q k v to BSH. 
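# (Note on the cache write above, with an assumed layout: the flat slot ids
#  are split into a block index and an in-block offset before the scatter,
#  e.g. with block_size = 128 a slot id of 300 maps to block 2, offset 44,
#  so npu_scatter_nd_update_ writes that token at key_cache[2, 44].)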
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if is_310p(): + # align q k v output tensors + query = aligned_16(query) + key = aligned_16(key) + value = aligned_16(value) + output = aligned_16(output) + + # do reformat in case of broadcasted tensors + mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1) + mask = torch_npu.npu_format_cast(mask.contiguous(), + ACL_FORMAT_FRACTAL_NZ) + + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) + output = output[:num_tokens, :, :] + elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + compress_mask = attn_metadata.attn_mask + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] + torch_npu._npu_flash_attention_qlens( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + block_table=block_table, + mask=compress_mask, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + decode_meta = attn_metadata.decode + assert decode_meta is not None + seq_lens = decode_meta.seq_lens_list + block_table = decode_meta.block_table + block_size = key_cache.shape[1] + query = query.view(num_tokens, 1, + self.num_heads * self.head_size).contiguous() + output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key_cache, + value=value_cache, + query_rope=None, + key_rope=None, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout='BSH', + atten_mask=decode_meta.attn_mask, + sparse_mode=0, + scale=self.scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=seq_lens, + ) + else: + raise NotImplementedError( + "Torchair graph mode with non-MLA attention backend is still experimental." + "v1 scheduler(chunked prefill) is not supported at this moment. 
Please" + "setting 'ascend_scheduler_config':{'enabled':true} in additional_config" + "to use ascend scheduler.") + + return output.view(num_tokens, self.hidden_size) diff --git a/vllm_npu/torchair/torchair_mla.py b/vllm_npu/torchair/torchair_mla.py new file mode 100644 index 0000000..92ce979 --- /dev/null +++ b/vllm_npu/torchair/torchair_mla.py @@ -0,0 +1,1310 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn as nn +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_npu.multistream.base import MSAttentionMetadataSplitConfig +from vllm_npu.multistream.context import get_multistream_comm_context +from vllm_npu.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_npu.ops.weight_prefetch import maybe_npu_prefetch +from vllm_npu.torchair.utils import (TorchairCommonAttentionMetadata, + npu_stream_switch, npu_wait_tensor) +from vllm_npu.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendMLATorchairBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_MLA_TORCHAIR" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendMLATorchairMetadata + + @staticmethod + def get_builder_cls(): + return AscendMLATorchairMetadataBuilder + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["MLAAttentionImpl"]: + return AscendMLATorchairImpl + + +@dataclass +class AscendMLATorchairPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class TorchairChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor + + attn_mask: torch.Tensor + query_lens: torch.Tensor + seq_lens: list[int] + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + chunked_context: Optional[TorchairChunkedContextMetadata] = None + sin: torch.Tensor = None + cos: torch.Tensor = None + + +@dataclass +class AscendMLATorchairDecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: 
Optional[list[int]] = None + attn_mask: Optional[torch.Tensor] = None + sin: torch.Tensor = None + cos: torch.Tensor = None + + +@dataclass +class AscendMLATorchairMetadata: + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendMLATorchairDecodeMetadata] = None + prefill: Optional[AscendMLATorchairPrefillMetadata] = None + enable_dbo_across_dp: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMLATorchairMetadata"]: + """Split metadata for multi-stream with AscendMLATorchairMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLATorchairMetadata, + ) + + +M = TypeVar("M", bound=AscendMLATorchairMetadata) + + +class AscendMLATorchairMetadataBuilder: + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendMLATorchairMetadata] = None): + self.metadata_cls: Optional[AscendMLATorchairMetadata] = metadata_cls \ + if metadata_cls is not None else AscendMLATorchairMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 
2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + # For torch air graph mode we treat spec decoding as decode. + if self.torchair_graph_enabled: + if num_tokens - num_spec_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + # For eager mode we treat spec decoding as chunked prefill. + else: + if num_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. 
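A minimal standalone sketch of that swap rule, using an assumed four-request batch rather than anything from this patch, shows how few swaps are needed:

# One prefill (8 scheduled tokens) sits between three decodes (1 token each).
num_scheduled = [1, 8, 1, 1]                                  # assumed per-request token counts
decodes = [i for i, n in enumerate(num_scheduled) if n == 1]  # [0, 2, 3]
prefills = [i for i, n in enumerate(num_scheduled) if n > 1]  # [1]
num_decodes, first_prefill, swaps = len(decodes), 0, []
for i in range(1, min(num_decodes, len(prefills)) + 1):
    # A decode sitting past the first num_decodes slots trades places with
    # the earliest prefill; otherwise the batch is already decode-first.
    if decodes[num_decodes - i] >= num_decodes:
        swaps.append((prefills[first_prefill], decodes[num_decodes - i]))
        first_prefill += 1
    else:
        break
assert swaps == [(1, 3)]  # one swap yields [decode, decode, decode, prefill]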
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] + + def build_torchair_graph_dummy( + self, + common_attn_metadata: TorchairCommonAttentionMetadata, + ) -> AscendMLATorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req + seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) + seq_lens_list = [0] * num_reqs + input_positions = torch.zeros(num_tokens, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_tokens, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) + sin = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + cos = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + if self.vllm_config.speculative_config is not None and\ + self.vllm_config.speculative_config.method == 'deepseek_mtp': + attn_state = AscendAttentionState.SpecDecoding + num_decode_tokens = 2 + else: + attn_state = AscendAttentionState.DecodeOnly + num_decode_tokens = 1 + decode_metadata = AscendMLATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=1, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=common_attn_metadata. 
+ actual_seq_lengths_q[:num_reqs], + sin=sin, + cos=cos, + ) + return self.metadata_cls( # type: ignore + num_input_tokens=common_attn_metadata.num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=1, + num_decode_tokens=num_decode_tokens, + num_prefills=0, + attn_mask=common_attn_metadata.attn_mask, + attn_state=attn_state, + prefill=None, + decode=decode_metadata, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_tables=block_table, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendMLATorchairMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + decode_threshold = common_attn_metadata.decode_token_per_req + else: + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. + device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[tokens_start:].max().item() + max_seq_lens = seq_lens[tokens_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + 
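A rough standalone sketch of the chunk partitioning computed next; every size here is assumed for illustration and none comes from the patch:

import torch

context_lens_cpu = torch.tensor([300, 700], dtype=torch.int32)  # cached context per prefill
max_context_chunk = 512   # workspace // num_prefills_with_context, rounded down to block_size
num_chunks = 2            # cdiv(max cached context, max_context_chunk)
chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, len(context_lens_cpu)) * max_context_chunk)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
# Each pass loads at most one workspace-sized slice of cached KV per request.
assert chunk_seq_lens.tolist() == [[300, 512], [0, 188]]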
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + prefill_metadata = AscendMLATorchairPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens[tokens_start:].to(torch.int32), + seq_lens=seq_lens, + context_lens=seq_lens[tokens_start:], + input_positions=prefill_input_positions, + block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + + decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size + use_torchair_graph = graph_pad_size != -1 + if num_decodes > 0: + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes] + input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] 
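The graph-mode padding that follows can be sanity-checked with a small sketch; the capture size, tokens-per-request, and sequence lengths below are assumed values, not taken from the patch:

# A graph captured for 8 tokens at 2 tokens per request, fed 3 live requests
# holding 6 decode tokens, needs 2 padded token slots and 1 padded request.
graph_pad_size, decode_token_per_req = 8, 2
num_reqs, num_decode_tokens = 3, 6
num_token_pad_size = graph_pad_size - num_decode_tokens                 # 2
num_reqs_pad_size = graph_pad_size // decode_token_per_req - num_reqs   # 1
padded_seq_lens = [5, 9, 4] + [0] * num_reqs_pad_size                   # pad_value = 0
assert (num_token_pad_size, num_reqs_pad_size) == (2, 1)
assert padded_seq_lens == [5, 9, 4, 0]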
+ num_token_pad_size = 0 + if use_torchair_graph and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + num_reqs_pad_size = 0 + if graph_pad_size != 0: + pad_value = 0 + num_token_pad_size = graph_pad_size - num_decode_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + padded_seq_lens = seq_lens.tolist( + ) + [pad_value] * num_reqs_pad_size + else: + padded_seq_lens = seq_lens.tolist() + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)) + seq_lens_list = padded_seq_lens + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_reqs_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_reqs + num_reqs_pad_size, block_table) + position_padding = torch.zeros(num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + actual_seq_lengths_q = self.pad_actual_seq_len_q( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + common_attn_metadata) + else: + seq_lens_list = seq_lens.tolist() + # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) + batch_size = num_decode_tokens + num_token_pad_size + if actual_seq_lengths_q[-1] != batch_size \ + and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + actual_seq_lengths_q[-1] = batch_size + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + decode_metadata = AscendMLATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + + return self.metadata_cls( # type: ignore + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + ) + + def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q, common_attn_metadata): + """ + Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request + in order to meet the requirement of npu_fused_infer_attention_score. + + In Torchair scenario, the lengths of the queries must be padded to the same length. + And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens). + + For example: + batch_size=36, num_reqs_pad_size=2, num_reqs=16 + By default, each request should have inference 2 token, which means actual_seq_lengths_q should be + [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. 
+ + However, mtp torchair + PD scenario, the actual_seq_lengths_q may be + [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. + In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. + after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] + """ + FIA_SEQ_LEN_LIMIT = 16 + need_padding = num_reqs_pad_size != 0 and \ + len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ + common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT + if need_padding: + padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + start_val = actual_seq_lengths_q[-1] + end_val = padding_seq_len_q[-1] + + num_step = len(padding_seq_len_q) + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == len(padding_seq_len_q) + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + else: + actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + + return actual_seq_lengths_q + + +class AscendMLATorchairImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[ + 'q_b_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.running_in_graph = False + self.prefill_mask = None + self.ring_mla_mask_size = 512 + + self.speculative_config = get_current_vllm_config().speculative_config + + def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + if hasattr(self, 
"running_in_graph") and not self.running_in_graph: + return x + MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB + maybe_npu_prefetch(self.o_proj.weight, + x, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=enable_multistream_mla) + return self.o_proj(x, is_prefill=False)[0] + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() + + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + + def _compute_prefill_context( + self, + query: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + rope_dim: int, + attn_metadata: AscendMLATorchairMetadata, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + ): + assert len(kv_c_and_k_pe_cache) > 1 + prefill_metadata = attn_metadata.prefill + if prefill_metadata is None or prefill_metadata.chunked_context is None: + return prefix_output, prefix_lse + + iters = len(prefill_metadata.chunked_context.seq_tot) + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + + current_seq_len = torch.tensor(prefill_metadata.query_lens, + dtype=torch.int32) + cache_kv_c = kv_c_and_k_pe_cache[0] + cache_k_pe = kv_c_and_k_pe_cache[1] + num_heads = cache_k_pe.size(2) + latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) + for i in range(iters): + toks = 
prefill_metadata.chunked_context.seq_tot[i] + + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) + kv_c_normed = torch.empty(toks, + num_heads, + latent_kv_dim, + dtype=query.dtype, + device=query.device) + k_pe = torch.empty(toks, + num_heads, + rope_dim, + dtype=query.dtype, + device=query.device) + + torch_npu.atb.npu_paged_cache_load( + cache_kv_c, + cache_k_pe, + prefill_metadata.block_table, + context_seq_len_npu, + seq_starts=prefill_metadata.chunked_context.starts[i], + key=kv_c_normed, + value=k_pe, + ) + + kv_c_normed = kv_c_normed.squeeze() + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=self.prefill_mask, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse) + return prefix_output, prefix_lse + + def _forward_prefill( + self, + query: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + attn_metadata: AscendMLATorchairMetadata, + ) -> torch.Tensor: + assert attn_metadata.prefill is not None + assert len(kv_c_and_k_pe_cache) > 1 + + num_tokens = query.size(0) + attn_output = torch.empty(num_tokens, + self.num_heads, + self.v_head_dim, + dtype=query.dtype, + device=query.device) + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=query.device) + k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + if self.prefill_mask is None: + if q_nope.dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + prefill_mask = torch.triu( + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, + device=q_nope.device, + dtype=q_nope.dtype), 1) + self.prefill_mask = torch.where(prefill_mask == 1, mask_value, + 0).to(q_nope.dtype) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=self.prefill_mask, + seqlen=attn_metadata.prefill.query_lens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) + + attn_output = attn_output.reshape( + [num_tokens, self.num_heads * self.v_head_dim]) + + return attn_output + + def exec_kv( + self, + hidden_states: torch.Tensor, 
+ cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + + B = hidden_states.shape[0] + N = self.num_kv_heads + S = 1 + kv = self.kv_a_proj_with_mqa(hidden_states)[0] + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + kv, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + ) + return k_pe, k_nope, kv + + def exec_kv_prefill( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + + B = hidden_states.shape[0] + N = self.num_kv_heads + S = 1 + kv = self.kv_a_proj_with_mqa(hidden_states)[0] + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + is_output_kv=True, + ) + return k_pe, k_nope + + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch_npu.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + attn_metadata: AscendMLATorchairMetadata, + enable_multistream_mla: bool = False, + ) -> torch.Tensor: + decode_meta = attn_metadata.decode + assert decode_meta is not None + num_tokens = q_nope.size(0) + if self.running_in_graph or self.running_chunkprefilll_with_torchair: + # shape of knope/k_pe for npu graph mode should be: + # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] + block_size = kv_c_and_k_pe_cache[0].shape[1] + actual_seq_lengths = None + if self.enable_kv_nz: + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // 16, block_size, 16) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // 16, block_size, 16) + input_layout = "BSND" + else: + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + input_layout = "BNSD" + + if attn_metadata.attn_state in [ + AscendAttentionState.SpecDecoding, + AscendAttentionState.ChunkedPrefill + ] and self.speculative_config is not None: + # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill + input_layout = "TND" + # [bs * q_seq_len, num_heads_per_rank, dim] + q_nope = q_nope.view(num_tokens, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, -1) + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + actual_seq_lengths = decode_meta.actual_seq_lengths_q + else: + if self.enable_kv_nz: + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None + + attn_output, _ = 
torch_npu.npu_fused_infer_attention_score( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=self.scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=decode_meta.block_table, + block_size=block_size, + actual_seq_lengths_kv=decode_meta.seq_lens_list, + actual_seq_lengths=actual_seq_lengths) + else: + # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will + # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become + # public available + assert len(kv_c_and_k_pe_cache) > 1 + if envs_ascend.vllm_npu_MLA_PA: + attn_output = torch_npu.atb.npu_multi_head_latent_attention( + q_nope, q_pe, kv_c_and_k_pe_cache[0], + kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, self.num_heads, self.scale, + self.num_kv_heads) + else: + q = torch.cat([q_nope, q_pe], dim=-1) + attn_output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device) + k_cache = torch.cat( + [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) + torch_npu._npu_paged_attention_mla( + query=q, + key_cache=k_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.decode. + block_table, # type:ignore + context_lens=attn_metadata.decode.seq_lens, # type:ignore + mla_vheadsize=self.kv_lora_rank, + out=attn_output) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self._v_up_proj_and_o_proj(attn_output, + enable_multistream_mla) + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self._v_up_proj_and_o_proj(attn_output) + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: Tuple[torch.Tensor], + attn_metadata: M, + output: Optional[torch.Tensor] = None, + enable_multistream_mla: bool = False, + ckq: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. 
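# (attn_metadata is None on the dummy/profiling pass, so a zero-filled
#  output of the right shape is enough there.)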
+ return output.fill_(0) + self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill + num_actual_toks = attn_metadata.num_actual_tokens + if k_pe is None and not self.running_in_graph: + kv_c, k_pe = self.kv_a_proj_with_mqa( + hidden_states_or_kv_c_normed)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + else: + kv_c_normed = hidden_states_or_kv_c_normed + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + if not self.running_in_graph: + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + if not self.torchair_graph_enabled: + kv_c_normed = kv_c_normed[:num_actual_toks, ...] + prefill_k_c_normed = kv_c_normed[num_decode_tokens:] + if not self.running_in_graph: + hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] + prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:] + # if not self.torchair_graph_enabled: + k_pe = k_pe[:num_actual_toks, ...] + k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] + else: + decode_hs_or_q_c = hidden_states_or_q_c + if has_decode: + decode_k_nope = None + assert attn_metadata.decode is not None + if self.running_in_graph or self.running_chunkprefilll_with_torchair: + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + if self.running_chunkprefilll_with_torchair: + decode_hs = ( + hidden_states_or_kv_c_normed[:num_decode_tokens]) + slots = attn_metadata.slot_mapping[:num_decode_tokens] + decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( + decode_hs, cos, sin, kv_cache, slots) + else: + with npu_stream_switch("mla_secondary", + 0, + enabled=enable_multistream_mla): + npu_wait_tensor(hidden_states_or_kv_c_normed, + ckq, + enabled=enable_multistream_mla) + decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( + hidden_states_or_kv_c_normed, cos, sin, kv_cache, + attn_metadata.slot_mapping) + # Without explicitly controlling the order, IndexByTensor operations + # would be placed after `matmul W_KV_T` hindering the overlapping of + # KvRmsNormRopeCache and SingleRope. + npu_wait_tensor(decode_hs_or_q_c, + cos, + enabled=enable_multistream_mla) + npu_wait_tensor(decode_hs_or_q_c, + sin, + enabled=enable_multistream_mla) + npu_wait_tensor(decode_hs_or_q_c, + decode_kv, + enabled=enable_multistream_mla) + + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_hs_or_q_c) + if self.running_in_graph: + with npu_stream_switch("mla_secondary", + 0, + enabled=enable_multistream_mla): + npu_wait_tensor(decode_q_pe, + decode_k_pe, + enabled=enable_multistream_mla) + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + elif self.running_chunkprefilll_with_torchair: + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + else: + decode_q_pe[...], decode_k_pe[...] 
= self.rotary_emb( + attn_metadata.decode.input_positions, + decode_q_pe.contiguous(), decode_k_pe) + if has_prefill: + assert attn_metadata.prefill is not None + prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] + if self.torchair_graph_enabled: + num_tokens = prefill_hs_or_q_c.shape[0] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + + prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( + prefill_hs, cos, sin, kv_cache, + attn_metadata.slot_mapping[num_decode_tokens:]) + + kv_c_normed = prefill_k_nope[:num_actual_toks, ...] + prefill_k_c_normed = prefill_k_nope + prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, + -1) + prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) + else: + prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( + attn_metadata.prefill.input_positions, + prefill_q_pe.contiguous(), prefill_k_pe) + + assert len( + kv_cache + ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" + if self.torchair_graph_enabled: + if kv_cache[0].numel() > 0 and has_prefill: + slots = attn_metadata.slot_mapping + # NOTE: Separate the kv cache in advance to avoid OOM or other issues + torch_npu._npu_reshape_and_cache( + key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1), + value=prefill_k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=slots[num_decode_tokens:]) + else: + kv_c_normed = kv_c_normed.view( + [num_actual_toks, self.num_kv_heads, -1]) + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=attn_metadata.slot_mapping) + if not self.running_in_graph: + o_proj_input_shape = (num_actual_toks, + self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) + if has_prefill: + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + o_proj_input[num_decode_tokens:] = output_prefill + else: + o_proj_input[num_decode_tokens:] = output_prefill + + if has_decode: + if self.running_in_graph: + return self._forward_decode(decode_ql_nope, decode_q_pe, + decode_k_nope, decode_k_pe, + kv_cache, attn_metadata, + enable_multistream_mla) + else: + output_decode = self._forward_decode(decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + o_proj_input[:num_decode_tokens] = output_decode + else: + o_proj_input[:num_decode_tokens] = output_decode + + current_ms_metadata = get_multistream_comm_context() + MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB + if current_ms_metadata is None: + maybe_npu_prefetch(self.o_proj.weight, + o_proj_input, + 
max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=enable_multistream_mla) + + output[...] = self.o_proj( + o_proj_input, + is_prefill=True, + is_force_scatter=self.enable_shared_expert_dp)[0] + else: + with torch.npu.stream(current_ms_metadata.comm_stream): + maybe_npu_prefetch(self.o_proj.weight, + o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=enable_multistream_mla) + output[...] = self.o_proj( + o_proj_input, + is_prefill=True, + is_force_scatter=self.enable_shared_expert_dp)[0] + current_ms_metadata.after_comm_event.record() + del o_proj_input + return output_padded diff --git a/vllm_npu/torchair/torchair_model_runner.py b/vllm_npu/torchair/torchair_model_runner.py new file mode 100644 index 0000000..d4c2633 --- /dev/null +++ b/vllm_npu/torchair/torchair_model_runner.py @@ -0,0 +1,557 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py +# isort: skip_file + +import math +import types +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch_npu +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.logger import logger + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.platform import NPUPlatform +from vllm_npu.torchair.utils import ( + TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata, + check_torchair_cache_exist, converting_weight_acl_format, + register_torchair_model, torchair_ops_patch, + torchair_quant_method_register, write_kv_cache_bytes_to_file) +from vllm_npu.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, + is_310p, get_ascend_soc_version, + AscendSocVersion) +from vllm_npu.worker.model_runner_v1 import NPUModelRunner + + +class NPUTorchairModelRunner(NPUModelRunner): + + def __init__(self, vllm_config: VllmConfig, device: torch.device): + self.ascend_config = get_ascend_config() + self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp + super().__init__(vllm_config, device) + if self.speculative_config: + self.actual_seq_lengths_q = list( + range(self.decode_token_per_req, self.max_num_tokens + 1, + self.decode_token_per_req)) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + None, None, vllm_config, device) + + register_torchair_model() + torchair_ops_patch() + torchair_quant_method_register() + if self.enable_shared_expert_dp: + return + self.new_kv_cache_bytes = -1 + self.torchair_compiled_model = None # type: ignore + self.torchair_compiled_models = {} # type: ignore + self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph + self.use_cached_kv_cache_bytes = 
self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
+        self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
+        if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
+            self.init_torchair_graph_batch_sizes()
+
+        self.update_torchair_graph_batch_sizes()
+
+        torch._dynamo.cache_size.config.cache_size_limit += len(
+            self.torchair_graph_batch_sizes)
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        torch._logging.set_logs(
+            recompiles=envs_ascend.vllm_npu_TRACE_RECOMPILES)
+
+        self._check_batch_sizes_consistency()
+
+    def _may_pad_kv_consumer_num_seq(self):
+        # The PD-disaggregation scenario needs redundant batch sizes so that no
+        # batch's seq_len exceeds 16 tokens; self.max_num_reqs therefore ends up
+        # greater than the actual maximum number of requests.
+        if self.decode_token_per_req > 1 and self.is_kv_consumer:
+            # Applied only when speculative decoding is active.
+            FIA_SEQ_LEN_LIMIT = 16
+            new_max_num_reqs = self.max_num_reqs + math.ceil(
+                self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
+                    (self.max_num_reqs * self.decode_token_per_req) /
+                    (FIA_SEQ_LEN_LIMIT**2))
+            if self.max_num_reqs < new_max_num_reqs:
+                logger.warning(
+                    f"max_num_reqs is updated to {new_max_num_reqs}")
+                self.max_num_reqs = new_max_num_reqs
+
+    def _init_mc2_tokens_capacity(self):
+        # NOTE: To be clear, we need to make sure that during graph capture, the number of
+        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
+        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
+        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
+        tp_size = self.parallel_config.tensor_parallel_size
+        max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+            max_num_tokens, tp_size)
+        self.mc2_tokens_capacity = max_graph_batch_size
+
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
+            logger.error(
+                f"A3: the max number of tokens must be smaller than 512, but is now {self.mc2_tokens_capacity}"
+            )
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
+            logger.error(
+                f"A2: the max number of tokens must be smaller than 256, but is now {self.mc2_tokens_capacity}"
+            )
+
+    def _sync_metadata_across_dp(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
+    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        """Override from NPUModelRunner to pad num_tokens."""
+        if self.enable_shared_expert_dp:
+            # Padding is not required for shared_expert_dp cases in eager mode.
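+            # No cross-DP all-reduce is needed on this path either; the
+            # original num_tokens is handed back unchanged.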
+ return num_tokens, None, with_prefill, enable_dbo + if self.dp_size == 1: + if not with_prefill: + maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + num_tokens) + return maybe_padded_num_tokens, None, with_prefill, enable_dbo + return num_tokens, None, with_prefill, enable_dbo + + num_tokens_across_dp = torch.zeros(self.dp_size + 2, + dtype=torch.int32, + device="npu") + num_tokens_across_dp[self.dp_rank] = num_tokens + num_tokens_across_dp[-2] = int(with_prefill) + num_tokens_across_dp[-1] = int(not enable_dbo) + dist.all_reduce(num_tokens_across_dp, + group=get_dp_group().device_group) + with_prefill = bool(num_tokens_across_dp[-2]) + enable_dbo = not bool(num_tokens_across_dp[-1]) + num_tokens_across_dp = num_tokens_across_dp[:-2] + + if not with_prefill: + max_num_token = num_tokens_across_dp.max().item() + maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + max_num_token) + num_tokens_across_dp = torch.full((self.dp_size, ), + maybe_padded_num_tokens, + dtype=torch.int32, + device="npu") + else: + maybe_padded_num_tokens = num_tokens + + return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo + + def _build_dummy_attn_metadata( + self, + with_prefill: bool, + num_reqs: int, + num_tokens: int, + max_query_len: int, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + ) -> Optional[dict[str, Any]]: + # NOTE: If torchair graph mode and not with_prefill, + # we can't skip_attn, it will cause graph recompile. + if with_prefill or self.enable_shared_expert_dp: + attn_metadata = super()._build_dummy_attn_metadata( + with_prefill, num_reqs, num_tokens, max_query_len, + aclgraph_runtime_mode, force_attention) + else: + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.actual_seq_lengths_q, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + decode_token_per_req=self.decode_token_per_req, + ) + attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( + common_attn_metadata) + return attn_metadata + + def _generate_dummy_run_hidden_states(self, with_prefill, + is_torchair_compile, input_ids, + positions, attn_metadata, num_tokens, + intermediate_tensors, inputs_embeds): + if with_prefill or self.enable_shared_expert_dp: + if is_310p(): + converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) + hidden_states = super()._generate_dummy_run_hidden_states( + with_prefill, is_torchair_compile, input_ids, positions, + attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) + else: + # Only mark static while compiling + if is_torchair_compile: + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static(attn_metadata.decode.input_positions) + torch._dynamo.mark_static(get_forward_context().mc2_mask) + if hasattr(attn_metadata.decode, "sin"): + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + if self.speculative_config: + torch._dynamo.mark_static(attn_metadata.decode.attn_mask) + for kv in self.kv_caches: + assert isinstance(kv, tuple), "kv_cache must be a tuple" + torch._dynamo.mark_static(kv[0]) + torch._dynamo.mark_static(kv[1]) + if is_310p(): + converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ) + + compiled_model = 
self._get_torchair_lazy_compiled_model(num_tokens)
+            model_kwargs = {}
+            model_kwargs["kv_caches"] = self.kv_caches
+            model_kwargs["attn_metadata"] = attn_metadata
+            hidden_states = compiled_model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=None,
+                **model_kwargs,
+            )
+        return hidden_states
+
+    def _convert_torch_format(self, kv_cache):
+        if self.enable_shared_expert_dp:
+            return super()._convert_torch_format(kv_cache)
+        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
+        return kv_cache
+
+    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
+        # Trigger torchair graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
+            for _ in range(self.vllm_config.compilation_config.
+                           cudagraph_num_of_warmups):
+                self._dummy_run(num_tokens, is_torchair_compile=True)
+            self._dummy_run(num_tokens, is_torchair_compile=True)
+            logger.info("Batch size %d compiled successfully: %d/%d.",
+                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))
+
+    def _capture_model(self):
+        """Override from NPUModelRunner to use torchair graph capture."""
+        if self.enable_shared_expert_dp:
+            return super()._capture_model()
+        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
+        # torchair graph capture can cause some issues, so now we just
+        # temporarily split the codepath for the two different graph patterns.
+        torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
+        graph_num = len(torchair_graph_batch_sizes)
+
+        if self.use_cached_npu_graph and not check_torchair_cache_exist():
+            # If graph caching is enabled but no usable cache exists yet (either
+            # use_cached_kv_cache_bytes is disabled or the kv_cache_bytes
+            # differ), we compile the model twice: the first pass generates the
+            # cache, and the second pass loads it to skip the overhead caused by
+            # the Dynamo guard mechanism.
+            logger.info(
+                "Cache compilation for torchair graph is enabled. Now we compile graph to generate"
+                " torchair cache, this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            self._compile_torchair_graph(torchair_graph_batch_sizes)
+            NPUPlatform.synchronize()
+            # Note: We reset dynamo below and reload the torchair computation
+            # graph that was just compiled and cached. This reduces graph launch
+            # time by 2-4 ms and avoids runtime errors caused by configuration
+            # mismatches in graph mode.
+ torch._dynamo.reset() + self.torchair_compiled_models.clear() + if self.use_cached_npu_graph: + logger.info( + "Loading torchair graph cache, this usually takes %.1f~%.1f mins.", + 0.3 * graph_num, 0.5 * graph_num) + self._compile_torchair_graph(torchair_graph_batch_sizes) + else: + logger.info( + "Capturing torchair graph, this usually takes %.1f~%.1f mins.", + 0.5 * graph_num, 1.5 * graph_num) + self._compile_torchair_graph(torchair_graph_batch_sizes) + + if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0: + write_kv_cache_bytes_to_file(torch.distributed.get_rank(), + self.new_kv_cache_bytes) + + def _use_aclgraph(self) -> bool: + if self.enable_shared_expert_dp: + return super()._use_aclgraph() + return False + + def _check_batch_sizes_consistency(self) -> None: + if not dist.is_initialized(): + return + + local = torch.tensor(self.torchair_graph_batch_sizes, + device="cpu", + dtype=torch.int32) + gathered_graph_batch_size = local.clone() + dist.all_reduce(gathered_graph_batch_size, + group=get_dp_group().cpu_group) + expected = local * self.dp_size + + if not torch.equal(gathered_graph_batch_size, expected): + diff_idxs = (gathered_graph_batch_size != expected).nonzero( + as_tuple=False).flatten().tolist() + raise AssertionError( + f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n" + f"Local (rank {self.dp_rank}): {local.tolist()}\n" + f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n" + f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}" + ) + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + if with_prefill or self.enable_shared_expert_dp: + super()._update_graph_pad_size(with_prefill, graph_pad_size) + else: + self.graph_pad_size = graph_pad_size + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + padded_num_tokens_across_dp): + """Override from NPUModelRunner to update input_ids and positions""" + input_ids, positions = super()._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + padded_num_tokens_across_dp) + + if with_prefill or self.enable_shared_expert_dp: + return input_ids, positions + else: + input_ids = self.input_ids[:padded_num_tokens_across_dp] + positions = self.positions[:padded_num_tokens_across_dp] + return input_ids, positions + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + padded_num_tokens_across_dp, + input_ids, positions, + intermediate_tensors, + inputs_embeds): + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] + + if self.enable_shared_expert_dp: + return super()._generate_process_reqs_hidden_states( + attn_metadata, with_prefill, padded_num_tokens_across_dp, + input_ids, positions, intermediate_tensors, inputs_embeds) + model_kwargs = { + "kv_caches": self.kv_caches, + "attn_metadata": attn_metadata + } + if not with_prefill: + if is_310p(): + converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ) + compiled_model = self._get_torchair_lazy_compiled_model( + padded_num_tokens_across_dp) + hidden_states = compiled_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + else: + assert self.model is not None + if is_310p(): + converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + 
intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + return hidden_states + + def _get_torchair_lazy_compiled_model(self, batch_size: int): + if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: + raise ValueError( + f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" + ) + + compiled_model = self.torchair_compiled_models.get( + batch_size + ) if self.use_cached_npu_graph else self.torchair_compiled_model + + if compiled_model: + return compiled_model + + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + patch_for_hcom() + + if is_310p(): + # on 300I Duo platform, we need to patch broadcast. however, this patch will be + # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. + from vllm_npu.patch.platform.patch_distributed import \ + communication_adaptation_310p + communication_adaptation_310p() + + config = torchair.CompilerConfig() + if self.ascend_config.torchair_graph_config.mode: + config.mode = self.ascend_config.torchair_graph_config.mode + config.experimental_config.frozen_parameter = \ + self.ascend_config.torchair_graph_config.enable_frozen_parameter + # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to + # disable it on 300I Duo platform now. + config.experimental_config.tiling_schedule_optimize = not is_310p() + config.experimental_config.enable_view_optimize = \ + self.ascend_config.torchair_graph_config.enable_view_optimize + torch.npu.set_compile_mode(jit_compile=False) + if not self.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=not self.use_sparse, + fullgraph=True, + backend=npu_backend) + return self.torchair_compiled_model + else: + # Generate a new forward proxy code object to prevent the invalidation of + # compilation cache caused by dynamo retracing + forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" + forward_fn = self.model.forward + code = forward_fn.__code__ + # Mark code object with a new proxy name + modified_code = code.replace(co_name=forward_proxy_name, ) + + modified_func = types.FunctionType(modified_code, + forward_fn.__globals__, + name=forward_proxy_name, + argdefs=forward_fn.__defaults__) + + self.model.__dict__[forward_proxy_name] = modified_func.__get__( + self.model, nn.Module) + self.torchair_compiled_models[ + batch_size] = torchair.inference.cache_compile( + self.model.__dict__[forward_proxy_name], + dynamic=not self.use_sparse, + fullgraph=True, + cache_dir=TORCHAIR_CACHE_DIR, + config=config, + ge_cache=False) + return self.torchair_compiled_models[batch_size] + + def init_torchair_graph_batch_sizes(self): + start_graph_batch_size = 4 + tp_size = get_tensor_model_parallel_world_size() + + # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks + start_graph_batch_size = max(start_graph_batch_size, tp_size) + + while (start_graph_batch_size <= self.max_num_reqs): + self.torchair_graph_batch_sizes.append(start_graph_batch_size) + start_graph_batch_size *= 2 + + def select_torchair_padded_batch_size(self, batch_size: int): + for padded_batch_size in self.torchair_graph_batch_sizes: + if batch_size <= padded_batch_size: + # we treat batch_size as num of requests + return padded_batch_size + raise ValueError( + f"cur batch_size is invalid, 
torchair_graph_batch_sizes is "
+            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
+        )
+
+    def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
+                                                tp_size):
+        cur_graph_batch_size = (old_graph_batch_size + tp_size -
+                                1) // tp_size * tp_size
+        # MTP > 1: take the least common multiple (LCM) of graph_batch_size and
+        # tp_size, so the result fits both multi-DP and the FIA operator.
+        if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
+            cur_graph_batch_size = (tp_size * old_graph_batch_size) \
+                // math.gcd(tp_size, old_graph_batch_size)
+        return cur_graph_batch_size
+
+    def update_torchair_graph_batch_sizes(self):
+        # Derive graph_batch_sizes from the max number of tokens:
+        # first pad according to the number of requests.
+        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+            # The PD-disaggregation scenario may calculate the batch incorrectly
+            # with MTP, so we force it to max_num_reqs.
+            self.torchair_graph_batch_sizes = [self.max_num_reqs]
+            logger.warning(
+                f"is kv_consumer, torchair_graph_batch_sizes is set to [max_num_reqs] {[self.max_num_reqs]}"
+            )
+        elif len(self.torchair_graph_batch_sizes) == 0:
+            self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+        else:
+            self.torchair_graph_batch_sizes = sorted(
+                self.torchair_graph_batch_sizes)
+            while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
+                self.torchair_graph_batch_sizes.pop()
+                if len(self.torchair_graph_batch_sizes) == 0:
+                    logger.warning(
+                        "torchair_graph_batch_sizes is invalid, resetting it to [1, max_num_reqs]"
+                    )
+                    self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+            if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
+                self.torchair_graph_batch_sizes.append(self.max_num_reqs)
+
+        # padded max number of tokens = max_num_reqs * decode_token_per_req
+        self.torchair_graph_batch_sizes = [
+            graph_batch_size * self.decode_token_per_req
+            for graph_batch_size in self.torchair_graph_batch_sizes
+        ]
+
+        # NOTE: when enable_expert_parallel is on for A3, we need to check that
+        # `graph_batch_size` is divisible by `tp_size`, because the A3
+        # dispatch/combine op uses x_active_mask, which requires the input shape
+        # to be the same on all EP ranks.
+        if get_ascend_soc_version(
+        ) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel:
+            self._align_graph_size_divisible_by_tp_size()
+
+    def _align_graph_size_divisible_by_tp_size(self):
+        tp_size = self.parallel_config.tensor_parallel_size
+        new_graph_batch_sizes = []
+        for graph_batch_size in self.torchair_graph_batch_sizes:
+            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
+                graph_batch_size, tp_size)
+            if cur_graph_batch_size not in new_graph_batch_sizes and \
+                    cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
+                new_graph_batch_sizes.append(cur_graph_batch_size)
+            elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
+                    and self.decode_token_per_req > 1:
+                logger.warning(
+                    f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens "
+                    f"{self.scheduler_config.max_num_batched_tokens}, will skip this batch size."
+ ) + new_max_num_reqs = math.ceil( + max(new_graph_batch_sizes) / self.decode_token_per_req) + if self.max_num_reqs != new_max_num_reqs: + logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}") + self.max_num_reqs = new_max_num_reqs + if not (self.decode_token_per_req > 1 and self.is_kv_consumer): + # Do not update scheduler_config.max_num_seqs in KV consumer + MTP + # Since FIA need extra space for padding + # Enforce self.max_num_seqs > self.scheduler_config.max_num_seqs in KV consumer + MTP + self.scheduler_config.max_num_seqs = new_max_num_reqs + + if new_graph_batch_sizes != self.torchair_graph_batch_sizes: + logger.warning( + f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}." + ) + self.torchair_graph_batch_sizes = new_graph_batch_sizes + + def _build_drafter_prepare_inputs_torchair_param(self): + if self.enable_shared_expert_dp: + return super()._build_drafter_prepare_inputs_torchair_param() + else: + return True diff --git a/vllm_npu/torchair/torchair_sfa.py b/vllm_npu/torchair/torchair_sfa.py new file mode 100644 index 0000000..fdcaaad --- /dev/null +++ b/vllm_npu/torchair/torchair_sfa.py @@ -0,0 +1,1333 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.attention.utils import (AscendCommonAttentionMetadata, + split_decodes_and_prefills) +from vllm_npu.multistream.base import MSAttentionMetadataSplitConfig +from vllm_npu.multistream.ms_split import model_input_split_v1_mla_attn +from vllm_npu.torchair.utils import TorchairCommonAttentionMetadata +from vllm_npu.utils import is_enable_nz +from vllm_npu.worker.npu_input_batch import InputBatch + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + + +class AscendSFATorchairBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_SFA_TORCHAIR" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AscendSFATorchairMetadata + + @staticmethod + def get_builder_cls(): + return AscendSFATorchairMetadataBuilder + + #NOTE: is that ok? 
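+    # The torchair SFA path allocates two paged caches (latent/nope and rope),
+    # each laid out per the shape returned below; forward() consumes them as
+    # kv_cache[0] / kv_cache[1].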
+ @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_impl_cls() -> Type["MLAAttentionImpl"]: + return AscendSFATorchairImpl + + +@dataclass +class AscendSFATorchairPrefillMetadata: + """ Prefill Specific Metadata for Ascend""" + + @dataclass + class TorchairChunkedContextMetadata: + # New for SFA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + + attn_mask: torch.Tensor + query_lens: list[int] # Check!! + seq_lens: list[int] # Check!! + context_lens: torch.Tensor + input_positions: torch.Tensor + query_start_loc: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_lens: int + sin: torch.Tensor + cos: torch.Tensor + chunked_context: Optional[TorchairChunkedContextMetadata] = None + + +@dataclass +class AscendSFATorchairDecodeMetadata: + # Input positions for rotrary embeddings since for SFA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + actual_seq_lengths_q: torch.Tensor + sin: torch.Tensor + cos: torch.Tensor + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendSFATorchairMetadata: + """Metadata for SFACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + slot_mapping: torch.Tensor + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + block_tables: torch.Tensor + + # New for SFA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. 
+ + query_lens: Optional[list[int]] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + attn_mask: torch.Tensor = None + # chunked prefill by default if no attn_states passed + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + + decode: Optional[AscendSFATorchairDecodeMetadata] = None + prefill: Optional[AscendSFATorchairPrefillMetadata] = None + enable_dbo_across_dp: bool = False + is_prefill: bool = False + is_decode: bool = False + + def __post_init__(self): + pass + # supported_head_sizes = AscendSFABackend.get_supported_head_sizes() + # if self.head_dim is not None and self.head_dim \ + # not in supported_head_sizes: + # raise ValueError( + # f"Only {supported_head_sizes} are supported for head_dim,", + # f"received {self.head_dim}.") + + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendSFATorchairMetadata"]: + """Split metadata for multi-stream with AscendSFATorchairMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendSFATorchairMetadata, + ) + + +M = TypeVar("M", bound=AscendSFATorchairMetadata) + + +class AscendSFATorchairMetadataBuilder: + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendSFATorchairMetadata] = None): + self.metadata_cls: Optional[AscendSFATorchairMetadata] = metadata_cls \ + if metadata_cls is not None else AscendSFATorchairMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + scheduler_config = vllm_config.scheduler_config + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * self.model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 SFA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the using the least amount + # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + # For torch air graph mode we treat spec decoding as decode. + if self.torchair_graph_enabled: + if num_tokens - num_spec_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + # For eager mode we treat spec decoding as chunked prefill. + else: + if num_tokens == 1: + decodes.append(i) + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + first_prefill = 0 + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + if decodes[num_decodes - i] >= num_decodes: + input_batch.swap_states(prefills[first_prefill], + decodes[num_decodes - i]) + first_prefill += 1 + modified_batch = True + else: + break + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + return modified_batch + + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + max_blocks = self.max_blocks + + graph_block_tables = torch.zeros((num_seqs, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:, :max_blocks] + + def build_torchair_graph_dummy( + self, + common_attn_metadata: TorchairCommonAttentionMetadata, + ) -> AscendSFATorchairMetadata: + device = self.device + num_reqs = common_attn_metadata.num_reqs + block_table = torch.zeros((num_reqs, self.max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + num_tokens = num_reqs * common_attn_metadata.decode_token_per_req + seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) + seq_lens_list = [0] * num_reqs + input_positions = torch.zeros(num_tokens, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_tokens, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) + sin = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, + dtype=self.model_config.dtype, + device=device) + cos = torch.ones(num_tokens, + 1, + 1, + self.rope_dim, 
+ dtype=self.model_config.dtype, + device=device) + + if self.vllm_config.speculative_config is not None and\ + self.vllm_config.speculative_config.method == 'deepseek_mtp': + attn_state = AscendAttentionState.SpecDecoding + num_decode_tokens = 2 + else: + attn_state = AscendAttentionState.DecodeOnly + num_decode_tokens = 1 + # cumsum here. + # actual_seq_lengths_q = torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_tokens]).to(torch.int32).npu() + # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).to(torch.int32).npu() + actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( + torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req ############## + decode_metadata = AscendSFATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=1, + attn_mask=common_attn_metadata.spec_attn_mask, + # actual_seq_lengths_q=torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_reqs]).to(torch.int32).npu(), + actual_seq_lengths_q=actual_seq_lengths_q, + # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu(), + sin=sin, + cos=cos, + ) + return self.metadata_cls( # type: ignore + num_input_tokens=common_attn_metadata.num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=0, + attn_mask=common_attn_metadata.attn_mask, + attn_state=attn_state, + prefill=None, + decode=decode_metadata, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_tables=block_table, + is_prefill=False, + is_decode=True) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendSFATorchairMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + decode_threshold = common_attn_metadata.decode_token_per_req + else: + # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding + decode_threshold = 1 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. 
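+        # Length bookkeeping below therefore stays on the *_cpu tensors from
+        # common_attn_metadata; only the final tensors are copied to the NPU
+        # with non_blocking=True.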
+ device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[: + num_actual_tokens].to( + device, + non_blocking=True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + 0].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + 0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + # check CPU operation here + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + query_lens = query_seq_lens_cpu[:num_reqs] + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] + num_computed_tokens_cpu = (seq_lens - query_lens) + + prefill_metadata = None + chunked_context_metadata = None + is_prefill = False + is_decode = False + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + tokens_start = num_decode_tokens + max_query_len = query_lens[tokens_start:].max().item() + max_seq_lens = seq_lens[tokens_start:].max().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + actual_query_lens = torch.tensor( + query_lens[tokens_start:], + dtype=torch.int32).npu() # int64->int32 + query_lens_prefill_sfa = torch.cumsum(actual_query_lens, + dim=0).to(torch.int32).npu() + seq_lens_prefill_sfa = torch.tensor(seq_lens, + dtype=torch.int32).npu() + prefill_metadata = AscendSFATorchairPrefillMetadata( + attn_mask=common_attn_metadata.attn_mask, + query_lens=query_lens_prefill_sfa, + seq_lens=seq_lens_prefill_sfa, + context_lens=seq_lens[tokens_start:], + input_positions=prefill_input_positions, + 
block_table=block_table[reqs_start:, ...], + max_query_len=max_query_len, + max_seq_lens=max_seq_lens, + query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, + ) + is_prefill = True + + decode_metadata = None + graph_pad_size = common_attn_metadata.graph_pad_size + use_torchair_graph = graph_pad_size != -1 + if num_decodes > 0: + # Check here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario + actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( + torch.int32).npu() + max_seq_lens = seq_lens[:num_decodes].max().item() + seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() + # input_positions = input_positions[:num_decode_tokens] + block_table = block_table[:num_decodes, ...] + num_token_pad_size = 0 + if use_torchair_graph and common_attn_metadata.attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + num_reqs_pad_size = 0 + if graph_pad_size != 0: + pad_value = 0 + num_token_pad_size = graph_pad_size - num_decode_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + padded_seq_lens = seq_lens.tolist( + ) + [pad_value] * num_reqs_pad_size + else: + padded_seq_lens = seq_lens.tolist() + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)).npu() + seq_lens_list = padded_seq_lens + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_reqs_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_reqs + num_reqs_pad_size, block_table) + position_padding = torch.zeros(num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + + # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).npu() + # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu() + actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( + torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req + # MTP ignored + # actual_seq_lengths_q = self.pad_actual_seq_len_q( + # num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + # common_attn_metadata) + else: + seq_lens_list = seq_lens.tolist() + # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) + batch_size = num_decode_tokens + num_token_pad_size + if actual_seq_lengths_q[-1] != batch_size \ + and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + actual_seq_lengths_q[-1] = batch_size + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + padded_token_num = input_positions.shape[0] + actual_seq_lengths_q = torch.arange( + 1, + (padded_token_num // common_attn_metadata.decode_token_per_req) + + 1).to(torch.int32).npu( + ) * common_attn_metadata.decode_token_per_req + decode_metadata = AscendSFATorchairDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + 
attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin, + cos=cos) + is_decode = True + + return self.metadata_cls( # type: ignore + num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + prefill=prefill_metadata, + decode=decode_metadata, + query_start_loc=query_start_loc, + block_tables=block_table, + seq_lens=seq_lens, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + is_prefill=is_prefill, + is_decode=is_decode) + + def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q, common_attn_metadata): + """ + Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request + in order to meet the requirement of npu_fused_infer_attention_score. + + In Torchair scenario, the lengths of the queries must be padded to the same length. + And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens). + + For example: + batch_size=36, num_reqs_pad_size=2, num_reqs=16 + By default, each request should have inference 2 token, which means actual_seq_lengths_q should be + [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. + + However, mtp torchair + PD scenario, the actual_seq_lengths_q may be + [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. + In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. + after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] + """ + FIA_SEQ_LEN_LIMIT = 16 + need_padding = num_reqs_pad_size != 0 and \ + len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ + common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT + if need_padding: + padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + start_val = actual_seq_lengths_q[-1] + end_val = padding_seq_len_q[-1] + + num_step = len(padding_seq_len_q) + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == len(padding_seq_len_q) + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + else: + actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + + # return actual_seq_lengths_q + return torch.Tensor(actual_seq_lengths_q).to(torch.int32).npu() + + +class PrefillSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = None + + +class DecodeSFAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + # nope_cache: Optional[torch.Tensor] = None + # rope_cache: Optional[torch.Tensor] = None + topk_indices: Optional[torch.Tensor] = None + query_states: Optional[torch.Tensor] = None + key_states: Optional[torch.Tensor] = 
None + bsz: Optional[int] = None + + +class AscendSFATorchairImpl(MLAAttentionImpl): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA Args + self.q_lora_rank = kwargs['q_lora_rank'] + self.kv_lora_rank = kwargs['kv_lora_rank'] + self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] + self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] + self.qk_head_dim = kwargs['qk_head_dim'] + self.v_head_dim = kwargs['v_head_dim'] + self.rotary_emb = kwargs['rotary_emb'] + self.q_proj = kwargs['q_proj'] + self.kv_b_proj = kwargs['kv_b_proj'] + self.o_proj = kwargs['o_proj'] + self.indexer = kwargs['indexer'] + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.q_a_proj = kwargs.get('q_a_proj', None) + self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.decoder_layer = kwargs.get('decoder_layer', None) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + if self.q_a_proj is not None: + self.q_b_proj = self.q_proj + else: + self.q_b_proj = None + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_prefetch = ascend_config.weight_prefetch_config.enabled + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + if ascend_config.torchair_graph_config.enabled: + self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[ + 0] + self.actual_seq_length = torch.arange(1, self.graph_batch_size + + 1).to(torch.int32).npu() + vllm_config = get_current_vllm_config() + self.ring_mla_mask_size = 512 + self.prefill_mask = None + + # indexer param + self.dim = self.indexer.dim + self.n_heads: int = self.indexer.n_heads # 64 + self.head_dim: int = self.indexer.head_dim # 128 + self.index_topk: int = self.indexer.index_topk # 2048 + self.wq_b = self.indexer.wq_b + self.wk = self.indexer.wk + self.weights_proj = self.indexer.weights_proj + self.k_norm = self.indexer.k_norm + self.softmax_scale = self.indexer.softmax_scale + + # Adapt torch air graph mode with spec decoding. 
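+        # With MTP speculative decoding enabled, each decode request carries
+        # extra speculative query tokens per step (decode_token_per_req > 1),
+        # which the torchair decode graph and metadata builder account for.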
+ speculative_config = vllm_config.speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 + + self.cp_size = 1 + + if self.q_a_proj is not None: + self.prefix = self.q_a_proj.prefix + else: + self.prefix = 0 + self.debug_layer_idx = int(self.prefix.split(".")[2]) + self.layers = vllm_config.model_config.hf_config.num_hidden_layers + self.first_k_dense_replace = vllm_config.model_config.hf_config.first_k_dense_replace + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() + # Convert from (L, N, P) to (N, P, L) + self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() + # Waiting for BMM NZ support + # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) + # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + if envs_ascend.vllm_npu_ENABLE_MLAPO: + self._process_weights_for_fused_mlapo(act_dtype) + + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): + kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data.clone() + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data.clone()), + dim=-1) + wd_qkv = wd_qkv.t().contiguous() + wd_qkv = transdata(wd_qkv, + block_size=(16, 32)).unsqueeze(0).contiguous() + if is_enable_nz(wd_qkv.dtype): + self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) + + kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone() + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, + self.qk_rope_head_dim) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( + self.kv_lora_rank + 
self.qk_rope_head_dim).contiguous() + self.deq_scale_qkv = torch.cat( + (kv_a_proj_deq_scl, self.q_a_proj.deq_scale.clone()), + dim=-1).contiguous() + + kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias.clone() + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, + self.qk_rope_head_dim) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.quant_bias_qkv = torch.cat( + (kv_a_proj_qt_bias, self.q_a_proj.quant_bias.clone()), + dim=-1).contiguous() + + wu_q = self.q_proj.weight.data.clone() + wu_q = wu_q.t().reshape(self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + -1) + wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) + wu_q = wu_q.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), + -1) + wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() + if is_enable_nz(wu_q.dtype): + self.wu_q = torch_npu.npu_format_cast(wu_q, 29) + + qb_deq_scl = self.q_proj.deq_scale.data.clone() + qb_deq_scl = qb_deq_scl.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) + self.qb_deq_scl = qb_deq_scl.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + qb_qt_bias = self.q_proj.quant_bias.data.clone() + qb_qt_bias = qb_qt_bias.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) + self.qb_qt_bias = qb_qt_bias.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + self.gamma0 = self.decoder_layer.input_layernorm.weight.data + self.beta0 = self.decoder_layer.input_layernorm.bias.data + self.gamma1 = self.q_a_layernorm.weight.data + self.beta1 = self.q_a_layernorm.bias.data + self.gamma2 = self.kv_a_layernorm.weight.data + self.quant_scale0 = self.q_a_proj.input_scale.data + self.quant_offset0 = self.q_a_proj.input_offset.data + self.quant_scale1 = self.q_proj.input_scale.data + self.quant_offset1 = self.q_proj.input_offset.data + + def _sfa_decode_preprocess(self, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv): + bsz = hidden_states.shape[0] + cos_shape = attn_metadata.decode.cos.shape + cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1]) + sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) + ctkv_scale = torch.tensor([1], + dtype=hidden_states.dtype, + device=hidden_states.device) + q_nope_scale = torch.tensor([1], + dtype=hidden_states.dtype, + device=hidden_states.device) + + decode_q_nope, _, decode_q_pe, _ = torch_npu.npu_mla_process( + hidden_states, + self.gamma0, + self.beta0, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.kv_b_proj_w_k, + kv_cache[0], + kv_cache[1], + attn_metadata.slot_mapping.flatten(), + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=ctkv_scale, + q_nope_scale=q_nope_scale, + cache_mode_opt="krope_ctkv", + quant_mode_opt="per_tensor_quant_asymm", + ) + decode_k_nope = kv_cache[0] + decode_k_pe = kv_cache[1] + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, 
self.num_heads, -1) + + hidden_states = self.decoder_layer.input_layernorm(hidden_states) + decode_kq = self.q_a_proj(hidden_states) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + + topk_indices = self.indexer_select(hidden_states, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache, + is_prefill=False) + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_pe) + decode_preprocess_res = DecodeSFAPreprocessResult( + q_nope=decode_q_nope, + q_pe=decode_q_pe, + topk_indices=topk_indices, + query_states=query_states, + key_states=key_states, + bsz=bsz, + ) + return decode_preprocess_res + + def forward( + self, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + if attn_metadata.prefill is not None: + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + bsz = 1 + + hidden_states_prefill = hidden_states + prefill_slot_mapping = attn_metadata.slot_mapping + prefill_kq = self.q_a_proj(hidden_states_prefill) # q down + prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm + prefill_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_prefill) # c_kv + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + prefill_kv_no_split = get_tp_group().all_gather( + prefill_kv_no_split, + 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + # prefill_q_c = q_c[ + # num_decode_tokens:num_actual_tokens] + + # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] + + # prefill_kv_no_split = kv_no_split[ + # num_decode_tokens:num_actual_tokens] + # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] + prefill_qr = prefill_q_c + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + prefill_qr = get_tp_group().all_gather( + prefill_qr, + 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + + prefill_q = self.q_b_proj(prefill_qr) + prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) + prefill_q_nope, prefill_q_pe = torch.split( + prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + prefill_q_nope = prefill_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + prefill_q_nope = (torch.matmul(prefill_q_nope, + self.kv_b_proj_w_k).transpose( + 1, + 0).view(-1, self.num_heads, + self.kv_lora_rank)) + prefill_q_pe = prefill_q_pe.unsqueeze(2) + + # stream2 kv + + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + cos_q, sin_q = cos, sin + + prefill_q_pe = torch_npu.npu_interleave_rope( + prefill_q_pe, cos_q, sin_q) # BNSD + prefill_q_pe = prefill_q_pe.squeeze(2) #BSH + # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? 
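+            # npu_kv_rmsnorm_rope_cache below RMS-norms the latent kv, applies
+            # rope to the k_pe part, and scatters both parts into the paged
+            # nope/rope caches at the prefill slot_mapping positions.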
+ + prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) + prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + prefill_latent_cache.view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), + self.kv_a_layernorm.weight, + cos.view(-1, 1, 1, self.qk_rope_head_dim), + sin.view(-1, 1, 1, self.qk_rope_head_dim), + prefill_slot_mapping.to(torch.int64), + rope_cache, + nope_cache, + k_rope_scale=None, + c_kv_scale=None, + k_rope_offset=None, + c_kv_offset=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode="PA") + + topk_indices = self.indexer_select(x=hidden_states_prefill, + qr=prefill_qr, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + is_prefill=True) + query_states = (prefill_q_nope, prefill_q_pe) + key_states = (prefill_k_nope, prefill_k_pe) + q_nope, q_pe = query_states + k_nope, k_rope = key_states + prefill_metadata = attn_metadata.prefill + + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=prefill_metadata.block_table, + actual_seq_lengths_query=prefill_metadata.query_lens, + actual_seq_lengths_kv=prefill_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul( + slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + # o_proj_input[num_decode_tokens:] = attn_output + output[...] = self.o_proj(attn_output, is_force_scatter=True) + return output + + elif attn_metadata.decode is not None: + if envs_ascend.vllm_npu_ENABLE_MLAPO: + prep_res = self._sfa_decode_preprocess(hidden_states, kv_cache, + attn_metadata, + need_gather_q_kv) + q_nope, q_pe = prep_res.query_states + k_nope, k_rope = prep_res.key_states + topk_indices = prep_res.topk_indices + else: + q_len = 1 + hidden_states_decode = hidden_states + decode_kq = self.q_a_proj(hidden_states_decode) # q down + decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm + decode_kv_no_split = self.kv_a_proj_with_mqa( + hidden_states_decode) # c_kv + # self.actual_seq_length = torch.arange(1,self.graph_batch_size+1).to(torch.int32).npu() + + # decode_q_c = q_c[:num_decode_tokens] + decode_slot_mapping = attn_metadata.slot_mapping + + decode_q = self.q_b_proj(decode_q_c) + bsz, _ = decode_q.shape + decode_q = decode_q.view(bsz, self.num_heads, 1, + self.qk_head_dim) # [16, 16, 1, 192] + decode_q_nope, decode_q_pe = torch.split( + decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) # [..., 128/64] + decode_q_nope = decode_q_nope.view( + -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) + decode_q_nope = (torch.matmul( + decode_q_nope, self.kv_b_proj_w_k).transpose(1, 0).view( + bsz, q_len, self.num_heads, self.kv_lora_rank)) + + # stream2 kv + key_cache = kv_cache[0] + value_cache = kv_cache[1] + cos = attn_metadata.decode.cos # [16, 1, 1, 64] + sin = attn_metadata.decode.sin + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze( + 1) + decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + decode_kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + 
decode_slot_mapping.to(torch.int64), + value_cache, + key_cache, + c_kv_scale=None, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode='PA') # adapter NZ + # nz_block_size = 16 + # KVCACHE_NZ_DIM = 16 + # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) + # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) + decode_q_pe = torch_npu.npu_interleave_rope( + decode_q_pe, cos, sin) # BNSD + + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, + self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + + topk_indices = self.indexer_select(hidden_states_decode, + decode_q_c, + attn_metadata=attn_metadata, + kv_cache=kv_cache, + is_prefill=False) + + query_states = (decode_q_nope, decode_q_pe) + key_states = (decode_k_nope, decode_k_rope) + q_nope, q_pe = query_states + k_nope, k_rope = key_states + + decode_metadata = attn_metadata.decode + slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=attn_metadata.decode.block_table, + actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=decode_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + slc_fa_fusion = slc_fa_fusion.squeeze(1) + slc_fa_fusion = slc_fa_fusion.transpose(0, 1) + + # input shape [N//attn_tp_size, T(bs*q_len), D] + # output shape [T(bs*q_len), N//attn_tp_size, D] + attn_output = torch.matmul( + slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( + -1, self.num_heads * self.v_head_dim) + output[...] = self.o_proj(attn_output) + return output + + def mla_epilog(self, + attn_output: torch.Tensor = None, + absorb: bool = False): + # TODO: + attn_output = self.o_proj(attn_output) + return attn_output + + def indexer_select( + self, + x: torch.Tensor, + qr: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + is_prefill: bool = True, + ): + if attn_metadata.prefill is not None: + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + elif attn_metadata.decode is not None: + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + # q process in new stream + q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + k_proj = get_tp_group().all_gather( + k_proj, 0)[attn_metadata.num_decode_tokens:attn_metadata. 
+ num_actual_tokens] + k = self.k_norm(k_proj).unsqueeze(1) + k_pe, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + + weights = self.weights_proj(x) + if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + weights = get_tp_group().all_gather( + weights, 0)[attn_metadata.num_decode_tokens:attn_metadata. + num_actual_tokens] + actual_seq_lengths_query = None + actual_seq_lengths_key = None + block_table = None + if attn_metadata.prefill is not None: + actual_seq_lengths_query = attn_metadata.prefill.query_lens + actual_seq_lengths_key = attn_metadata.prefill.seq_lens + + block_table = attn_metadata.prefill.block_table + elif attn_metadata.decode is not None: + actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q + actual_seq_lengths_key = attn_metadata.decode.seq_lens.to( + torch.int32) + + block_table = attn_metadata.decode.block_table + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3) + return topk_indices + + +def round_up(val: int, align: int) -> int: + if align == 0: + return 0 + return -(val // -align) * align + + +def trans_rope_weight(weight, rope_dim): + weight_1 = weight[..., -rope_dim::2, :].contiguous() + weight_2 = weight[..., -rope_dim + 1::2, :].contiguous() + weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2) + + return weight.contiguous() + + +def transdata(nd_mat, block_size: tuple = (16, 16)): + r = round_up(nd_mat.shape[0], block_size[0]) + c = round_up(nd_mat.shape[1], block_size[1]) + r_pad = r - nd_mat.shape[0] + c_pad = c - nd_mat.shape[1] + nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad))) + nz_mat = torch.permute( + torch.reshape( + nd_mat, + (r // block_size[0], block_size[0], c // block_size[1], + block_size[1]), + ), + [2, 0, 1, 3], + ) + nz_mat = torch.reshape( + nz_mat, + (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) + return nz_mat diff --git a/vllm_npu/torchair/torchair_worker.py b/vllm_npu/torchair/torchair_worker.py new file mode 100644 index 0000000..916b67a --- /dev/null +++ b/vllm_npu/torchair/torchair_worker.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +from vllm.logger import logger + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.torchair.torchair_model_runner import NPUTorchairModelRunner +from vllm_npu.torchair.utils import (check_kv_cache_bytes_cache_exist, + delete_torchair_cache_file, + read_kv_cache_bytes_from_file) +from vllm_npu.worker.worker_v1 import NPUWorker + + +class NPUTorchairWorker(NPUWorker): + """Torchair worker bases on NPUWorker. Only torchair specified code should be added in this class.""" + + def determine_available_memory(self) -> int: + """Override determine_available_memory to use cached torchair kv_cache_bytes.""" + + available_kv_cache_memory = super().determine_available_memory() + ascend_config = get_ascend_config() + if ascend_config.enable_shared_expert_dp: + return available_kv_cache_memory + if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: + if check_kv_cache_bytes_cache_exist(): + old_kv_cache_bytes = read_kv_cache_bytes_from_file( + torch.distributed.get_rank()) + if 0 < old_kv_cache_bytes <= available_kv_cache_memory: + logger.info( + f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" + ) + self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes + return old_kv_cache_bytes + else: + logger.info( + "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" + ) + delete_torchair_cache_file() + bytes_floating_tolerance = 1024 * 1024 * envs_ascend.vllm_npu_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE + available_kv_cache_memory -= bytes_floating_tolerance + logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") + self.model_runner.new_kv_cache_bytes = available_kv_cache_memory + return available_kv_cache_memory + + def init_device(self): + """Override init_device to init torchair model runner""" + device = self._init_device() + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = NPUTorchairModelRunner(self.vllm_config, device) diff --git a/vllm_npu/torchair/utils.py b/vllm_npu/torchair/utils.py new file mode 100644 index 0000000..4ac9db2 --- /dev/null +++ b/vllm_npu/torchair/utils.py @@ -0,0 +1,275 @@ +import fcntl +import os +import shutil +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass + +import torch +import torch_npu +from torchair.scope import super_kernel as _super_kernel + +try: + # Recent release of torchair has moved these ops to `.scope`. + from torchair.scope import npu_stream_switch as _npu_stream_switch + from torchair.scope import npu_wait_tensor as _npu_wait_tensor +except ImportError: + from torchair.ops import NpuStreamSwitch as _npu_stream_switch + from torchair.ops import npu_wait_tensor as _npu_wait_tensor + +import vllm_npu.envs as envs_ascend +from vllm_npu.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz + +KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes" +KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes" +TORCHAIR_CACHE_PATH_NAME = ".torchair_cache" +TORCHAIR_CACHE_DIR = os.path.join( + os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME) + + +@dataclass +class TorchairCommonAttentionMetadata: + """ + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. 
+ """ + + num_reqs: int + """Number of requests""" + + num_actual_tokens: int + """Total number of tokens in batch""" + + decode_token_per_req: int + + actual_seq_lengths_q: list[int] + + attn_mask: torch.Tensor = None + + spec_attn_mask: torch.Tensor = None + + graph_pad_size: int = -1 + + +@contextmanager +def _file_lock(file_descriptor, lock_type): + fcntl.flock(file_descriptor, lock_type) + try: + yield + finally: + fcntl.flock(file_descriptor, fcntl.LOCK_UN) + + +def _get_torchair_current_work_dir(file_name=None): + if file_name is None: + return TORCHAIR_CACHE_DIR + return os.path.join(TORCHAIR_CACHE_DIR, file_name) + + +def check_torchair_cache_exist(): + res = False + torch_air_abs_path = _get_torchair_current_work_dir() + if os.path.exists(torch_air_abs_path): + file_list = os.listdir(torch_air_abs_path) + if len(file_list) != 0: + res = True + return res + + +def check_kv_cache_bytes_cache_exist(): + res = False + kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir( + KV_CACHE_BYTES_CACHE_PATH_NAME) + if os.path.exists(kv_cache_bytes_cache_abs_path): + file_list = os.listdir(kv_cache_bytes_cache_abs_path) + if len(file_list) != 0: + res = True + return res + + +def read_kv_cache_bytes_from_file(rank) -> int: + kv_cache_bytes = -1 + kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir( + KV_CACHE_BYTES_CACHE_PATH_NAME) + kv_cache_bytes_file = os.path.join( + kv_cache_bytes_cache_abs_path, + f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}") + with open(kv_cache_bytes_file, "r", encoding="utf-8") as f: + with _file_lock(f, fcntl.LOCK_SH): + kv_cache_bytes = int(f.readline()) + return kv_cache_bytes + + +def write_kv_cache_bytes_to_file(rank, kv_cache_bytes): + kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir( + KV_CACHE_BYTES_CACHE_PATH_NAME) + os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True) + kv_cache_bytes_file = os.path.join( + kv_cache_bytes_cache_abs_path, + f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}") + with open(kv_cache_bytes_file, "w", encoding="utf-8") as f: + with _file_lock(f, fcntl.LOCK_EX): + f.write(f"{kv_cache_bytes}") + + +def delete_torchair_cache_file(): + torch_air_abs_path = _get_torchair_current_work_dir() + try: + shutil.rmtree(torch_air_abs_path) + except FileNotFoundError: + pass + + +def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True): + return _npu_stream_switch(tag, priority) if enabled else nullcontext() + + +def npu_wait_tensor(self: torch.Tensor, + dependency: torch.Tensor, + *, + enabled: bool = True): + return _npu_wait_tensor(self, dependency) if enabled else self + + +def converting_weight_acl_format(model, format): + # currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ + # in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ + # is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this + # conversion when using torchair graph mode on 300I Duo platform. + # TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant + # accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode. 
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + for module in model.modules(): + if isinstance(module, FusedMoE): + if torch_npu.get_npu_format(module.w13_weight.data) == format: + return + if format == ACL_FORMAT_FRACTAL_NZ \ + and not is_enable_nz(module.w13_weight.data.dtype): + return + module.w13_weight.data = torch_npu.npu_format_cast( + module.w13_weight.data, format) + module.w2_weight.data = torch_npu.npu_format_cast( + module.w2_weight.data, format) + + +def register_torchair_model(): + from vllm import ModelRegistry + + ModelRegistry.register_model( + "DeepSeekMTPModel", + "vllm_npu.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP" + ) + + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_npu.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM" + ) + + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_npu.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" + ) + + ModelRegistry.register_model( + "DeepseekV32ForCausalLM", + "vllm_npu.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" + ) + + ModelRegistry.register_model( + "Qwen2ForCausalLM", + "vllm_npu.torchair.models.qwen2:CustomQwen2ForCausalLM") + + ModelRegistry.register_model( + "Qwen3MoeForCausalLM", + "vllm_npu.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM") + + ModelRegistry.register_model( + "PanguProMoEForCausalLM", + "vllm_npu.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" + ) + + +def torchair_quant_method_register(): + from vllm_npu.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP + from vllm_npu.torchair.quantization.torchair_w4a8_dynamic import ( + TorchairAscendW4A8DynamicFusedMoEMethod, + TorchairAscendW4A8DynamicLinearMethod) + from vllm_npu.torchair.quantization.torchair_w8a8_dynamic import ( + TorchairAscendW8A8DynamicFusedMoEMethod, + TorchairAscendW8A8DynamicLinearMethod) + + ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ + "linear"] = TorchairAscendW8A8DynamicLinearMethod + ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ + "moe"] = TorchairAscendW8A8DynamicFusedMoEMethod + ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ + "linear"] = TorchairAscendW4A8DynamicLinearMethod + ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ + "moe"] = TorchairAscendW4A8DynamicFusedMoEMethod + + +def torchair_ops_patch(): + from vllm_npu.ops.activation import AscendSiluAndMul + from vllm_npu.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm + from vllm_npu.ops.rotary_embedding import ( + AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) + from vllm_npu.ops.vocab_parallel_embedding import \ + AscendVocabParallelEmbedding + from vllm_npu.torchair.ops import (torchair_activation, + torchair_layernorm) + from vllm_npu.torchair.ops.torchair_rotary_embedding import ( + deepseek_rope_init_func, native_rope_deepseek_forward, + qwen_rope_init_func, rope_forward) + from vllm_npu.torchair.ops.torchair_vocab_parallel_embedding import \ + vocab_embedding_forward + + AscendRotaryEmbedding.__init__ = qwen_rope_init_func # type: ignore[method-assign] + AscendRotaryEmbedding.forward_oot = rope_forward # type: ignore[method-assign] + + AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign] + AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign] + + AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign] + AscendRMSNorm.forward_oot = 
torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] + + AscendQuantRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign] + AscendQuantRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] + + AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign] + AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign] + + +def super_kernel(prefix: str, option: str, enabled: bool = True): + return _super_kernel(prefix, option) if enabled else nullcontext() + + +# TODO(ttanzhiqiang): rm_router_logits +# dp>1 will trigger +# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors. +def get_rm_router_logits_state(ep_size: int, dp_size: int, + is_deepseek_v3_r1: bool): + # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep + # only supports deepseek v3/r1 + if dp_size > 1: + if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 + and is_deepseek_v3_r1): + return True + elif ep_size == 1 and is_deepseek_v3_r1: + return True + return False + + +# TODO(ttanzhiqiang): all_reduce merge +# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce +# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model. +def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): + # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep + # only supports deepseek v3/r1 + if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 + and is_deepseek_v3_r1): + return True + elif ep_size == 1 and is_deepseek_v3_r1: + return True + return False diff --git a/vllm_npu/utils.py b/vllm_npu/utils.py new file mode 100644 index 0000000..f4150e8 --- /dev/null +++ b/vllm_npu/utils.py @@ -0,0 +1,830 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/vllm/worker/worker.py +# + +import atexit +import functools +import math +import os +from contextlib import contextmanager, nullcontext +from enum import Enum +from threading import Lock +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import torch +import torch_npu # noqa: F401 +from packaging.version import InvalidVersion, Version +from torch_npu.npu.streams import Event +from vllm.logger import logger + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +ASCEND_QUANTIZATION_METHOD = "ascend" +SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"] +REGISTERED_ASCEND_OPS = {} + +ACL_FORMAT_FRACTAL_ND = 2 +ACL_FORMAT_FRACTAL_NZ = 29 + +_CUSTOM_OP_ENABLED = None +_IS_310P = None +_SLEEP_MODE_ENABLED = None +_CURRENT_STREAM = None +_PREFETCH_STREAM = None +_SHARED_EXPERTS_COMPUTE_STREAM = None +_ASCEND_CUSTOMOP_IS_REIGISTERED = False +_DEFAULT_BUFFER_SIZE = 200 +_MIN_DP_BUFFER_SIZE = 50 +_IS_MOE_MODEL = None +_IS_VL_MODEL = None +_ENABLE_SP = None +_HAS_LAYER_IDX = None +_ENABLE_NZ = None +_IS_EAGLE_MODE = None + + +def is_310p(): + global _IS_310P + if _IS_310P is None: + from vllm_npu import _build_info # type: ignore + _IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p") + return _IS_310P + + +def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8, + vllm_config: Optional[VllmConfig] = None) -> bool: + global _ENABLE_NZ, _IS_EAGLE_MODE + if _ENABLE_NZ is None: + if not vllm_config: + raise ValueError( + "vllm_config must be provided when _ENABLE_NZ is None") + _ENABLE_NZ = envs_ascend.vllm_npu_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next" + + _IS_EAGLE_MODE = (vllm_config.speculative_config is not None + and getattr(vllm_config.speculative_config, 'method', + None) in ("eagle", "eagle3")) + + if dtype in [torch.float16, torch.bfloat16, torch.float32]: + return _ENABLE_NZ if _IS_EAGLE_MODE else False + return _ENABLE_NZ + + +def sleep_mode_enabled(): + global _SLEEP_MODE_ENABLED + if _SLEEP_MODE_ENABLED is None: + from vllm_npu import _build_info # type: ignore + _SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__ + return _SLEEP_MODE_ENABLED + + +def _round_up(x: int, align: int): + # round up x to align, for example, if align is 16, x will be rounded up to 16, 32, 48, etc. + # input: 15, 16 -> output: 16 + # input: 17, 16 -> output: 32 + # input: 30, 16 -> output: 32 + # input: 33, 16 -> output: 48 + # ... 
+ return (x + align - 1) // align * align + + +def _custom_pad(x, pad_dims): + # pad the input tensor to the shape of pad_dims + # input: (13, 30), pad_dims: [0, 2, 0, 3] + # output: (16, 32) + return torch.nn.functional.pad(x, pad_dims) + + +def _custom_reshape(x, target_shape): + # reshape the input tensor to the shape of target_shape + # input: (16, 32), target_shape: [1, 16, 2, 16] + # output: (1, 16, 2, 16) + return x.reshape(target_shape) + + +def _custom_transpose(x, dim1, dim2): + # transpose the input tensor + # input: (1, 16, 2, 16), dim1: 1, dim2: 2 + # output: (1, 2, 16, 16) + return x.transpose(dim1, dim2) + + +def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor: + # in_tensor: (13, 30) + aux_dims = [1, 0, 0, 16] + # aux_dims[1]: 16 + aux_dims[1] = _round_up(in_tensor.size(0), 16) + # aux_dims[2]: 2 + aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16 + + # after: aux_dims: [1, 16, 2, 16] + + pad_dims = [0, 0, 0, 0] + # pad_dims[1]: 2 + pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1) + # pad_dims[3]: 3 + pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0) + + # after: pad_dims: [0, 2, 0, 3] + + # return: (1, 2, 16, 16) + return _custom_transpose( + _custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1, + 2).contiguous() + + +def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor: + num_tokens = mask_tensor.shape[0] + max_seq_len = mask_tensor.shape[1] + + tokens_pad = (num_tokens + 15) // 16 * 16 + max_seq_len_pad = (max_seq_len + 15) // 16 * 16 + + mask_tensor_pad = \ + torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device) + mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor + mask = mask_tensor_pad.reshape( + (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3) + return mask + + +def aligned_16(tensor: torch.Tensor): + """Aligned tensor for 310P""" + + # Get the size of the current 0th dimension + n = tensor.size(0) + + # Calculate the aligned size + n_aligned = ((n + 15) // 16) * 16 + + # If already aligned, return the original tensor + if n == n_aligned: + return tensor + + # Create a new tensor with shape (n_aligned, H, W) and fill it with zeros + new_tensor = torch.zeros(n_aligned, + *tensor.shape[1:], + dtype=tensor.dtype, + device=tensor.device) + + # Copy the original tensor to the first N positions of the new tensor + new_tensor[:n] = tensor + + return new_tensor + + +def try_register_lib(lib_name: str, lib_info: str = ""): + import importlib + import importlib.util + try: + module_spec = importlib.util.find_spec(lib_name) + if module_spec is not None: + importlib.import_module(lib_name) + if lib_info: + logger.info(lib_info) + except Exception: + pass + + +def enable_custom_op(): + """ + Enable lazy init for vllm_npu_C to avoid early initialization of CANN's RTS component. + Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device(). 
+ """ + global _CUSTOM_OP_ENABLED + if _CUSTOM_OP_ENABLED is not None: + return _CUSTOM_OP_ENABLED + try: + # isort: off + # register custom ops into torch_library here + import vllm_npu.vllm_npu_C # type: ignore # noqa: F401 + # register the meta implementation for custom kernel if necessary + import vllm_npu.meta_registration # type: ignore # noqa: F401 + # isort: on + _CUSTOM_OP_ENABLED = True + except ImportError: + _CUSTOM_OP_ENABLED = False + logger.warning( + "Warning: Failed to register custom ops, all custom ops will be disabled" + ) + return _CUSTOM_OP_ENABLED + + +def find_hccl_library() -> str: + """ + We either use the library file specified by the `HCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libhccl.so` can be + found by `ctypes` automatically. + """ + so_file = envs_ascend.HCCL_SO_PATH + + # manually load the hccl library + if so_file: + logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", + so_file) + else: + if torch.version.cann is not None: + so_file = "libhccl.so" + else: + raise ValueError("HCCL only supports Ascend NPU backends.") + logger.info("Found hccl from library %s", so_file) + return so_file + + +def current_stream() -> torch.npu.Stream: + """ + replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`. + it turns out that `torch.npu.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.npu.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.npu.current_stream()`. + + """ + global _CURRENT_STREAM + if _CURRENT_STREAM is None: + # when this function is called before any stream is set, + # we return the default stream. + _CURRENT_STREAM = torch.npu.current_stream() + return _CURRENT_STREAM + + +def prefetch_stream() -> torch.npu.Stream: + global _PREFETCH_STREAM + if _PREFETCH_STREAM is None: + # when this function is called before any stream is set, + # we return the default stream. + _PREFETCH_STREAM = torch_npu.npu.Stream() + return _PREFETCH_STREAM + + +def shared_experts_compute_stream() -> torch.npu.Stream: + global _SHARED_EXPERTS_COMPUTE_STREAM + if _SHARED_EXPERTS_COMPUTE_STREAM is None: + # when this function is called before any stream is set, + # we return the default stream. + _SHARED_EXPERTS_COMPUTE_STREAM = torch_npu.npu.Stream() + return _SHARED_EXPERTS_COMPUTE_STREAM + + +def adapt_patch(is_global_patch: bool = False): + if is_global_patch: + from vllm_npu.patch import platform # noqa: F401 + else: + from vllm_npu.patch import worker # noqa: F401 + + +@functools.cache +def vllm_version_is(target_vllm_version: str): + if envs_ascend.VLLM_VERSION is not None: + vllm_version = envs_ascend.VLLM_VERSION + else: + import vllm + vllm_version = vllm.__version__ + try: + return Version(vllm_version) == Version(target_vllm_version) + except InvalidVersion: + raise ValueError( + f"Invalid vllm version {vllm_version} found. A dev version of vllm " + "is installed probably. Set the environment variable VLLM_VERSION " + "to control it by hand. 
And please make sure the value follows the " + "format of x.y.z.") + + +def get_max_hidden_layers(hf_config) -> int: + cfg_dict = hf_config.to_dict() + layer_counts = [] + + def _rec_find(d): + if isinstance(d, dict): + for k, v in d.items(): + if k == "num_hidden_layers" and isinstance(v, int): + layer_counts.append(v) + else: + _rec_find(v) + + _rec_find(cfg_dict) + if not layer_counts: + raise ValueError("Not found num_hidden_layers in model config.") + return max(layer_counts) + + +def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool: + """ + Check whether it is vLLM default capture sizes. + """ + + cuda_graph_sizes = vllm_config.scheduler_config.cuda_graph_sizes + if len(cuda_graph_sizes) == 1: + default_size_capture_list = [1, 2, 4] + [ + i for i in range(8, cuda_graph_sizes[0] + 1, 8) + ] + + if sorted(default_size_capture_list, reverse=True) == \ + vllm_config.compilation_config.cudagraph_capture_sizes: + return True + + return False + + +def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None: + """ + Update ACL graph default capture sizes, so that new sizes + are more friendly to ascend ops && hardware. + """ + + if vllm_config.model_config is None or \ + vllm_config.model_config.enforce_eager or \ + not _is_default_capture_sizes(vllm_config): + return + + # modify the default capture_sizes for Qwen3-MoE models on dp settings. + # this is mainly because performance of _npu_paged_attention might degrades + # on special shapes. + # TODO(Angazenn): we will remove this once _npu_paged_attention is fully + # replaced by npu_fused_infer_attention_score which does not contain such bugs. + if vllm_config.model_config and vllm_config.model_config.hf_config.model_type == "qwen3_moe" \ + and vllm_config.parallel_config.tensor_parallel_size == 1 \ + and vllm_config.parallel_config.data_parallel_size > 1 : + max_capture_size = vllm_config.scheduler_config.cuda_graph_sizes[0] + new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [ + i for i in range(24, max_capture_size + 1, 8) + ] + + vllm_config.compilation_config.cudagraph_capture_sizes = new_cudagraph_capture_sizes + vllm_config.compilation_config.init_with_cudagraph_sizes( + new_cudagraph_capture_sizes) + + +def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: + """Update ACL graph capture sizes based on hardware limitations""" + # NOTE: Currently, we can only capture 1800 graphs at most, + # due to the limitation of ACL graph. This number is bounded by + # the number of streams, which is 2048, we save 248 streams + # as a buffer. 
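+    # (2048 available streams - 248 reserved as buffer = 1800, hence MAX_CAPTURE_SIZE below.)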
+ # Maximum number of graphs that can be captured by ACL Graph + # TODO: Find out whether we need to solve allreduce function + MAX_CAPTURE_SIZE = 1800 + + # Store original configuration and temporarily clear it + compilation_config = vllm_config.compilation_config + original_sizes, compilation_config.cudagraph_capture_sizes = \ + compilation_config.cudagraph_capture_sizes, None + + # Calculate parallel configuration factor + hf_config = vllm_config.model_config.hf_config + if hasattr(hf_config, 'num_hidden_layers'): + num_hidden_layers = hf_config.num_hidden_layers + else: + num_hidden_layers = get_max_hidden_layers(hf_config) + parallel_config = vllm_config.parallel_config + + # Calculate maximum supported batch sizes considering model architecture + resources_per_graph = num_hidden_layers + 1 + if vllm_config.speculative_config is not None: + draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config + resources_per_graph += draft_model_hf_config.num_hidden_layers + 1 + + # TODO: Find out whether we need to take into account the pp_size + num_comm_groups = sum(size > 1 for size in [ + parallel_config.data_parallel_size, + parallel_config.tensor_parallel_size, + ]) + + if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV': + # TODO: Find out whether we need to take into account the pp_size + parallel_factor = 1 + num_comm_groups + int( + parallel_config.enable_expert_parallel) + int( + vllm_config.additional_config.get( + "multistream_overlap_shared_expert", False)) + if is_moe_model(vllm_config): + parallel_factor += (parallel_config.data_parallel_size > 1) + else: + # When AIV mode is enabled, the allreduce operator of the dense + # layer model will occupy additional streams, which are buffered here. + MAX_CAPTURE_SIZE = MAX_CAPTURE_SIZE - parallel_factor * resources_per_graph + + # Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device + # Assume the following case: + # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, + # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19 + max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / + resources_per_graph / parallel_factor) + logger.info( + "Calculated maximum supported batch sizes for ACL graph: %s", + max_num_batch_sizes) + else: + # The above describes an empirical formula applicable to the A2 hardware. + # Under this configuration, HCCL employs the FFTS+ method for execution unfolding, + # which adds only 1 concurrent stream without consuming collective communication execution unfolding streams. + # On A3 hardware, HCCL defaults to the AICPU method. + # This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case). + # Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes. 
+ # Therefore, the calculation formula has been modified as follows: + # Assume the following case: + # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, + # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12 + max_num_batch_sizes = math.floor( + (MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / + (1 + num_comm_groups * 2)) + logger.info( + "Calculated maximum supported batch sizes for ACL graph: %s", + max_num_batch_sizes) + logger.warning( + "Currently, communication is performed using FFTS+ method, which reduces " + "the number of available streams and, as a result, limits the range of runtime " + "shapes that can be handled. To both improve communication performance and " + "increase the number of supported shapes, set HCCL_OP_EXPANSION_MODE=AIV." + ) + + # If original sizes exceed maximum, sample a representative subset + if max_num_batch_sizes < len(original_sizes): + # Sample uniformly from original sizes + step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1) + indices = [round(i * step) for i in range(max_num_batch_sizes)] + + # Ensure first and last elements are preserved + indices[0], indices[-1] = 0, len(original_sizes) - 1 + + sampled_sizes = [original_sizes[i] for i in indices] + compilation_config.init_with_cudagraph_sizes(sampled_sizes) + + logger.info( + "Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes", + vllm_config.model_config.architectures[0], + num_hidden_layers, + len(original_sizes), + len(compilation_config. + cudagraph_capture_sizes # type: ignore[arg-type] + )) + else: + # No adjustment needed + compilation_config.cudagraph_capture_sizes = original_sizes + logger.info( + "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes", + vllm_config.model_config.architectures[0], num_hidden_layers, + len(original_sizes)) + + # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario + # the maximum size cudagraph_capture_sizes[0] should be greater or equal than + # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode + if vllm_config.speculative_config is not None and \ + vllm_config.speculative_config.num_speculative_tokens > 1: + num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + original_sizes, compilation_config.cudagraph_capture_sizes = \ + compilation_config.cudagraph_capture_sizes, None + assert len(original_sizes) > 0 + if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs: + enlarged_sizes = [(num_speculative_tokens + 1) * size + for size in original_sizes] + compilation_config.init_with_cudagraph_sizes(enlarged_sizes) + logger.info( + "Adjusted ACL graphs: %s → %s for speculative decoding", + original_sizes, enlarged_sizes) + else: + compilation_config.cudagraph_capture_sizes = original_sizes + + +# TODO(wxy): Move to ops module +def dispose_tensor(x: torch.Tensor): + x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype)) + + +class ProfileExecuteDuration: + _instance = None + _observations: List[Tuple[str, Event, Event]] = [] + _lock = Lock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + atexit.register(cls._instance.destroy) + return cls._instance + + def destroy(self): + with self._lock: + self._observations.clear() + + @contextmanager + def 
capture_async(self, duration_tag: str): + if not envs_ascend.vllm_npu_MODEL_EXECUTE_TIME_OBSERVE: + yield + return + + observe_start = Event(enable_timing=True) + observe_start.record() + try: + yield + finally: + observe_end = Event(enable_timing=True) + observe_end.record() + with self._lock: + self._observations.append( + (duration_tag, observe_start, observe_end)) + + def pop_captured_sync(self) -> dict: + """Pop and synchronize all events in the observation list""" + durations: dict[str, float] = {} + if not envs_ascend.vllm_npu_MODEL_EXECUTE_TIME_OBSERVE: + return durations + + while self._observations: + with self._lock: + tag, observe_start, observe_end = self._observations.pop() + observe_end.synchronize() + durations[tag] = observe_start.elapsed_time(observe_end) + + return durations + + +def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): + """Register Ascend CustomOP + + NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, + and ensure this will execute after model config is initilazed. + """ + global _ASCEND_CUSTOMOP_IS_REIGISTERED + if _ASCEND_CUSTOMOP_IS_REIGISTERED: + return + from vllm.model_executor.custom_op import CustomOp + + from vllm_npu.models.layers.mla import AscendMultiHeadLatentAttention + from vllm_npu.ops.activation import AscendQuickGELU, AscendSiluAndMul + from vllm_npu.ops.common_fused_moe import (AscendFusedMoE, + AscendSharedFusedMoE) + from vllm_npu.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm + from vllm_npu.ops.linear import (AscendColumnParallelLinear, + AscendMergedColumnParallelLinear, + AscendQKVParallelLinear, + AscendReplicatedLinear, + AscendRowParallelLinear) + from vllm_npu.ops.rotary_embedding import ( + AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding, + AscendRotaryEmbedding, AscendYaRNRotaryEmbedding) + from vllm_npu.ops.vocab_parallel_embedding import ( + AscendLogitsProcessor, AscendParallelLMHead, + AscendVocabParallelEmbedding) + + global REGISTERED_ASCEND_OPS + REGISTERED_ASCEND_OPS = { + "QuickGELU": AscendQuickGELU, + "SiluAndMul": AscendSiluAndMul, + "RotaryEmbedding": AscendRotaryEmbedding, + "MRotaryEmbedding": AscendMRotaryEmbedding, + "ColumnParallelLinear": AscendColumnParallelLinear, + "RowParallelLinear": AscendRowParallelLinear, + "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding, + "MergedColumnParallelLinear": AscendMergedColumnParallelLinear, + "QKVParallelLinear": AscendQKVParallelLinear, + "ReplicatedLinear": AscendReplicatedLinear, + "DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding, + "VocabParallelEmbedding": AscendVocabParallelEmbedding, + "ParallelLMHead": AscendParallelLMHead, + "LogitsProcessor": AscendLogitsProcessor, + "RMSNorm": AscendRMSNorm, + "GemmaRMSNorm": AscendGemmaRMSNorm, + "FusedMoE": AscendFusedMoE, + "SharedFusedMoE": AscendSharedFusedMoE, + "MultiHeadLatentAttention": AscendMultiHeadLatentAttention, + } + + for name, op_cls in REGISTERED_ASCEND_OPS.items(): + CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) + + # NOTE: Keep this at last to ensure all custom actions are registered + _ASCEND_CUSTOMOP_IS_REIGISTERED = True + + +# TODO(zzzzwwjj): Currently there is no clear SOC_VERSION policy for A2 and A3 in CANN. +# So we get the version dynamically. In the future, we should get the version info from _build_info like 310p does. 
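+# For reference, init_ascend_soc_version() below maps the raw soc_version as follows:
+#   220 <= soc_version <= 225 -> AscendSocVersion.A2
+#   250 <= soc_version <= 255 -> AscendSocVersion.A3
+#   anything else             -> AscendSocVersion.UNDEFINED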
+class AscendSocVersion(Enum): + A2 = 0 + A3 = 1 + UNDEFINED = 2 + + +_ascend_soc_version = None + + +def init_ascend_soc_version(): + soc_version = torch_npu.npu.get_soc_version() + global _ascend_soc_version + if 220 <= soc_version <= 225: + _ascend_soc_version = AscendSocVersion.A2 + elif 250 <= soc_version <= 255: + _ascend_soc_version = AscendSocVersion.A3 + else: + _ascend_soc_version = AscendSocVersion.UNDEFINED + + +def get_ascend_soc_version(): + global _ascend_soc_version + assert _ascend_soc_version is not None + return _ascend_soc_version + + +def lmhead_tp_enable() -> bool: + return get_ascend_config().lmhead_tensor_parallel_size is not None + + +def oproj_tp_enable() -> bool: + return get_ascend_config().oproj_tensor_parallel_size is not None + + +def mlp_tp_enable() -> bool: + return envs_ascend.vllm_npu_ENABLE_MLP_OPTIMIZE + + +def matmul_allreduce_enable() -> bool: + return envs_ascend.vllm_npu_ENABLE_MATMUL_ALLREDUCE + + +def dense_optim_enable() -> bool: + return envs_ascend.vllm_npu_ENABLE_DENSE_OPTIMIZE + + +def enable_sp(vllm_config=None) -> bool: + global _ENABLE_SP + if _ENABLE_SP is None: + if vllm_config is None: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + _ENABLE_SP = ( + vllm_config.compilation_config.pass_config. + enable_sequence_parallelism + or envs_ascend.vllm_npu_ENABLE_FLASHCOMM1 + # Flash comm 1 should be enabled by env vllm_npu_ENABLE_FLASHCOMM1 + # We retain the env vllm_npu_ENABLE_FLASHCOMM here for backward compatibility. + or bool(int(os.getenv("vllm_npu_ENABLE_FLASHCOMM", '0')))) + + return _ENABLE_SP + + +# TODO remove it after vllm has this func +def shared_expert_dp_enabled() -> bool: + return get_ascend_config().enable_shared_expert_dp or enable_sp() + + +def is_moe_model(vllm_config: VllmConfig): + """Checks if the model is a MoE model by config""" + global _IS_MOE_MODEL + if _IS_MOE_MODEL is None: + model_configs = vllm_config.model_config.hf_config.to_dict() + _IS_MOE_MODEL = _is_contain_expert(model_configs) + return _IS_MOE_MODEL + + +def _is_contain_expert(config: Any): + if isinstance(config, dict): + for k, v in config.items(): + if "expert" in str(k): + return True + if _is_contain_expert(v): + return True + return False + + +def is_vl_model(vllm_config: VllmConfig): + """Checks if the model is a VL model by config""" + global _IS_VL_MODEL + if _IS_VL_MODEL is None and vllm_config.model_config: + model_configs = vllm_config.model_config.hf_config.to_dict() + _IS_VL_MODEL = "VL" in model_configs["architectures"][0] + return _IS_VL_MODEL + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return torch.ops._C_ascend.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + + This function should be used in the following scenario: + When a tensor is created during graph capture, and it's held by a method + that's not part of the graph, we don't really need to store it, but we + **do need** its buffer pointer. If we don't handle this, it cannot + be garbage collected, leading to a memory leak. 
To avoid this, + we should create a weak reference to the tensor. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors") + + +def npu_stream_switch(target_stream: torch.npu.Stream, + *, + enabled: bool = True): + """ + Switch to the target stream if enabled is True. + Otherwise, do nothing. + """ + if not enabled: + return nullcontext() + assert target_stream is not None + return torch.npu.stream(target_stream) + + +def create_hccl_pg_options(group_name: str): + options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options() + hccl_config = get_hccl_config_for_pg_options(group_name) + if hccl_config is not None: + options.hccl_config = hccl_config + return options + + +def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]: + """ + Get HCCL process group options for the given communication group name. + + Args: + group_name: Name of the communication group + + Returns: + HCCL pg_options or None for mc2 group + """ + # FIXME: Current mc2 operators only perform communication space partitioning + # based on HCCL_BUFFSIZE configuration. Using pg_options with mc2 group would + # result in memory misalignment problems. + if group_name and "mc2" in group_name: + return None + hccl_config_map = { + "dp": { + "hccl_buffer_size": calculate_dp_buffer_size() + }, + } + return hccl_config_map.get(group_name, get_default_buffer_config()) + + +def get_default_buffer_config() -> dict: + return {"hccl_buffer_size": _DEFAULT_BUFFER_SIZE} + + +def calculate_dp_buffer_size() -> int: + """ + formula of dp buffer size: + dp_size + 2 (flags: with_prefill and enable_dbo) + """ + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + dp_size = vllm_config.parallel_config.data_parallel_size + int32_size = torch.iinfo(torch.int32).bits // 8 + dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024)) + return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE) + + +# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 +# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and +# significantly improve communication performance of MC2 ops dispatch/combine. 
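+# Example (illustrative shell setup):
+#   export HCCL_INTRA_PCIE_ENABLE=1
+#   export HCCL_INTRA_ROCE_ENABLE=0
+# With both variables set this way, is_hierarchical_communication_enabled() returns
+# True; any other combination (including unset) returns False.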
+def is_hierarchical_communication_enabled(): + return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" + and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1") + + +def has_layer_idx(model_instance: torch.nn.Module) -> bool: + if model_instance is None: + return False + + global _HAS_LAYER_IDX + if _HAS_LAYER_IDX is None: + _HAS_LAYER_IDX = hasattr(model_instance, "model") and \ + hasattr(model_instance.model, "start_layer") + return _HAS_LAYER_IDX diff --git a/vllm_npu/worker/__init__.py b/vllm_npu/worker/__init__.py index b2dcbf3..e69de29 100644 --- a/vllm_npu/worker/__init__.py +++ b/vllm_npu/worker/__init__.py @@ -1 +0,0 @@ -"""Ascend NPU worker implementation.""" diff --git a/vllm_npu/worker/block_table.py b/vllm_npu/worker/block_table.py new file mode 100644 index 0000000..307eb83 --- /dev/null +++ b/vllm_npu/worker/block_table.py @@ -0,0 +1,312 @@ +from typing import Optional, Union + +import numpy as np +import torch +from vllm.distributed import get_dcp_group +from vllm.utils import cdiv + + +class BlockTable: + + def __init__(self, + block_size: int, + max_num_reqs: int, + max_num_blocks_per_req: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + kernel_sizes: Union[list[int], None] = None): + self.max_num_reqs = max_num_reqs + self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens + self.pin_memory = pin_memory + self.device = device + self.physical_block_size = block_size + # If kernel_sizes is None or [0], use physical block size (no splitting) + if kernel_sizes is None or kernel_sizes == [0]: + self.block_size = block_size + self.logical_block_size = block_size + self.blocks_per_phys_block = 1 + self.use_hybrid_blocks = False + else: + # Find the first kernel size that divides physical_block_size evenly + selected_kernel_size = None + for kernel_size in kernel_sizes: + if kernel_size > 0 \ + and self.physical_block_size % kernel_size == 0: + selected_kernel_size = kernel_size + break + + if selected_kernel_size is None: + raise ValueError( + f"None of the kernel sizes {kernel_sizes} can divide " + f"physical block size {self.physical_block_size} evenly") + + self.block_size = selected_kernel_size + self.logical_block_size = selected_kernel_size + self.blocks_per_phys_block = (self.physical_block_size // + self.logical_block_size) + if self.blocks_per_phys_block > 1: + self.use_hybrid_blocks = True + else: + self.use_hybrid_blocks = False + + if self.use_hybrid_blocks: + logical_table_size = (max_num_blocks_per_req * + self.blocks_per_phys_block) + else: + logical_table_size = max_num_blocks_per_req + + self.block_table = torch.zeros( + (max_num_reqs, logical_table_size), + device=self.device, + dtype=torch.int32, + ) + self.block_table_cpu = torch.zeros( + (max_num_reqs, logical_table_size), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_np = self.block_table_cpu.numpy() + self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.slot_mapping = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device) + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + 
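+        # Illustrative hybrid-block layout (assumed numbers): with physical_block_size=128
+        # and kernel_sizes=[64], logical_block_size=64 and blocks_per_phys_block=2, so
+        # physical block 5 maps to logical entries 10 and 11 in the expanded block table
+        # (see _convert_physical_to_logical_blocks below).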
self.kernel_sizes = kernel_sizes + + def append_row( + self, + block_ids, + row_idx: int, + ) -> None: + if not block_ids: + return + block_ids = np.array(block_ids) + if self.use_hybrid_blocks: + block_ids = self._convert_physical_to_logical_blocks(block_ids) + + num_blocks = len(block_ids) + start = self.num_blocks_per_row[row_idx] + + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.num_blocks_per_row[row_idx] += num_blocks + + def add_row(self, block_ids: list[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) + + def move_row(self, src: int, tgt: int) -> None: + num_blocks = self.num_blocks_per_row[src] + self.block_table_np[tgt, :num_blocks] = self.block_table_np[ + src, :num_blocks] + self.num_blocks_per_row[tgt] = num_blocks + + def swap_row(self, src: int, tgt: int) -> None: + num_blocks_src = self.num_blocks_per_row[src] + num_blocks_tgt = self.num_blocks_per_row[tgt] + self.num_blocks_per_row[src] = num_blocks_tgt + self.num_blocks_per_row[tgt] = num_blocks_src + + self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] + + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + + if self.dcp_world_size > 1: + # Note(hc): The DCP implement store kvcache with an interleave + # style, the kvcache for the token whose token_idx is i is + # always stored on the GPU whose dcp_rank equals i % cp_world_size: + + # Use a "virtual block" which equals to world_size * block_size + # for block_table_indices calculation. + virtual_block_size = self.block_size * self.dcp_world_size + + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // virtual_block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = (req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + + logical_block_idx) + + block_numbers = self.block_table_np.ravel()[block_table_indices] + # Use virtual_block_size for mask calculation, which marks local + # tokens. 
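+            # Worked example (illustrative numbers): block_size=128 and dcp_world_size=2
+            # give virtual_block_size=256. For position=5 the virtual offset is 5, so the
+            # token is local only to dcp_rank=1 (5 % 2 == 1) and lands at local block
+            # offset 5 // 2 == 2; non-local ranks write -1 into slot_mapping below.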
+ virtual_block_offsets = positions % virtual_block_size + mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + # Calculate local block_offsets + block_offsets = virtual_block_offsets // self.dcp_world_size + # Calculate slot_mapping + slot_mapping = block_numbers * self.block_size + block_offsets + # Write final slots, use -1 for not-local + self.slot_mapping_np[:req_indices.shape[0]] = np.where( + mask, slot_mapping, -1) + else: + assert self.kernel_sizes is not None + if self.block_size == self.kernel_sizes[0]: + # IMPORTANT: In hybrid mode, positions are in logical block space, + # but we need to map them to the correct logical block table indices + logical_block_idx = positions // self.block_size + + # Account for the expanded logical table + # (always needed with unified tensor) + # Each physical block is split into multiple logical blocks + # The logical table has been expanded to accommodate this + block_table_indices = ( + req_indices * self.max_num_blocks_per_req * + self.blocks_per_phys_block + logical_block_idx) + + block_numbers = self.block_table_np.ravel( + )[block_table_indices] + block_offsets = positions % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:req_indices.shape[0]]) + + def commit_block_table(self, num_reqs: int) -> None: + self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], + non_blocking=True) + + def commit_slot_mapping(self, num_tokens: int) -> None: + self.slot_mapping[:num_tokens].copy_( + self.slot_mapping_cpu[:num_tokens], non_blocking=True) + + def clear(self) -> None: + self.block_table.fill_(0) + self.block_table_cpu.fill_(0) + + def _convert_physical_to_logical_blocks( + self, physical_blocks: np.ndarray) -> np.ndarray: + """Convert physical block IDs to logical block IDs.""" + if not self.use_hybrid_blocks: + return physical_blocks + + # Create logical block IDs by splitting each physical block + logical_blocks: list[int] = [] + for phys_block in physical_blocks: + # Convert physical block to multiple logical blocks + # Physical block 1 becomes logical blocks + # [1*split_ratio, 1*split_ratio+1, ...] + # But we need to account for the fact that block 0 is special + base_logical = phys_block * self.blocks_per_phys_block + logical_blocks.extend( + range(base_logical, base_logical + self.blocks_per_phys_block)) + + return np.array(logical_blocks, dtype=np.int32) + + def get_device_tensor(self) -> torch.Tensor: + """Returns the device tensor of the block table.""" + return self.block_table + + def get_cpu_tensor(self) -> torch.Tensor: + """Returns the CPU tensor of the block table.""" + return self.block_table_cpu + + def get_numpy_array(self) -> np.ndarray: + """Returns the numpy array of the block table.""" + return self.block_table_np + + +class MultiGroupBlockTable: + """The BlockTables for each KV cache group.""" + + def __init__(self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0, + kernel_sizes: Optional[list[list[int]]] = None) -> None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. 
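+        # For instance (assumed values): max_model_len=4096, block_size=128 and
+        # dcp_world_size=2 give cdiv(4096, 256) = 16 blocks per request, floored at
+        # 1 + num_speculative_tokens by the max() below.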
+ try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + + if kernel_sizes is None: + kernel_sizes = [[0]] * len(block_sizes) + # Ensure kernel_sizes matches block_sizes length + elif len(kernel_sizes) == 1 and len(block_sizes) > 1: + kernel_sizes = kernel_sizes * len(block_sizes) + elif len(kernel_sizes) != len(block_sizes): + raise ValueError( + f"kernel_sizes length ({len(kernel_sizes)}) must match " + f"block_sizes length ({len(block_sizes)})") + + # Use zip to pair block_sizes with kernel_sizes one-to-one + self.block_tables = [ + BlockTable( + block_size, max_num_reqs, + max(cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens), max_num_batched_tokens, + pin_memory, device, kernel_size_list) + for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) + ] + + def append_row(self, block_ids: tuple[list[int], ...], + row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) + + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) + + def move_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.move_row(src, tgt) + + def swap_row(self, src: int, tgt: int) -> None: + for block_table in self.block_tables: + block_table.swap_row(src, tgt) + + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + for block_table in self.block_tables: + block_table.compute_slot_mapping(req_indices, positions) + + def commit_block_table(self, num_reqs: int) -> None: + for block_table in self.block_tables: + block_table.commit_block_table(num_reqs) + + def commit_slot_mapping(self, num_tokens: int) -> None: + for block_table in self.block_tables: + block_table.commit_slot_mapping(num_tokens) + + def clear(self) -> None: + for block_table in self.block_tables: + block_table.clear() + + def __getitem__(self, idx: int) -> "BlockTable": + """Returns the BlockTable for the i-th KV cache group.""" + return self.block_tables[idx] diff --git a/vllm_npu/worker/model_runner_v1.py b/vllm_npu/worker/model_runner_v1.py new file mode 100644 index 0000000..3dfd4f2 --- /dev/null +++ b/vllm_npu/worker/model_runner_v1.py @@ -0,0 +1,3713 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py +# + +import copy +import gc +import itertools +import math +import re +import time +from collections import defaultdict +from collections.abc import Iterator +from contextlib import contextmanager, nullcontext +from copy import deepcopy +from dataclasses import dataclass +from multiprocessing import Manager +from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, + Union, cast) + +import numpy as np +import numpy.typing as npt +import torch +import torch._dynamo.cache_size +import torch.distributed as dist +import torch.nn as nn +from tqdm import tqdm # type: ignore +from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.layer import Attention +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed import tensor_model_parallel_all_gather +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, + get_tp_group, + is_global_first_rank) +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.logger import logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.interfaces import supports_transcription +from vllm.model_executor.models.interfaces_base import ( + VllmModelForPooling, is_pooling_model, is_text_generation_model) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingType +from vllm.sequence import IntermediateTensors +from vllm.tasks import GenerationTask, PoolingTask, SupportedTask +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + LazyLoader, cdiv, get_dtype_size, + is_pin_memory_available) +from vllm.utils.jsontree import json_map_leaves +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +# yapf conflicts with isort for this block +# yapf: disable +from vllm.v1.kv_cache_interface import (AttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, MLAAttentionSpec, + UniformTypeKVCacheSpecs) +# yapf: enable +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, LogprobsTensors, ModelRunnerOutput, + PoolerOutput) +from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer +from 
vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache, + gather_mm_placeholders, + sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) + +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config +from vllm_npu.ascend_forward_context import (MoECommType, + set_ascend_forward_context) +from vllm_npu.attention.attention_mask import AttentionMaskBuilder +from vllm_npu.attention.attention_v1 import AscendAttentionState +from vllm_npu.attention.utils import AscendCommonAttentionMetadata +from vllm_npu.compilation.acl_graph import (ACLGraphWrapper, + set_graph_params, + update_attn_params, + update_mla_attn_params) +from vllm_npu.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor +from vllm_npu.eplb.core.eplb_device_transfer_loader import \ + D2DExpertWeightLoader +from vllm_npu.eplb.core.eplb_utils import EPLBParamUtils +from vllm_npu.eplb.core.eplb_worker import EplbProcess +from vllm_npu.eplb.eplb_updator import EplbUpdator +from vllm_npu.eplb.utils import model_register +from vllm_npu.models.layers.mla import AscendMultiHeadLatentAttention +from vllm_npu.multistream.ms_split import compute_split_seq_index +from vllm_npu.ops.weight_prefetch import WeightPrefetchMethod +from vllm_npu.platform import NPUPlatform +from vllm_npu.sample.logits_processor import build_logitsprocs +from vllm_npu.sample.rejection_sampler import AscendRejectionSampler +from vllm_npu.spec_decode import get_spec_decode_method +from vllm_npu.spec_decode.eagle_proposer import EagleProposer +from vllm_npu.spec_decode.interface import SpecDcodeType +from vllm_npu.spec_decode.mtp_proposer import MtpProposer +from vllm_npu.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, + AscendSocVersion, ProfileExecuteDuration, + enable_sp, get_ascend_soc_version, is_310p, + is_enable_nz, is_moe_model, lmhead_tp_enable) +from vllm_npu.worker.npu_input_batch import CachedRequestState, InputBatch + +if TYPE_CHECKING: + import xgrammar as xgr # type: ignore[import-untyped] + from vllm.v1.core.sched.output import SchedulerOutput +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +import torch_npu + +# if true, allow tensor initialization and casting with internal format (e.g., NZ) +torch.npu.config.allow_internal_format = True + +if is_310p(): + torch_npu.npu.set_compile_mode(jit_compile=False) + ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ +else: + ACL_FORMAT = ACL_FORMAT_FRACTAL_ND + + +@dataclass +class GraphCaptureContext: + stream: torch.npu.Stream + + +@contextmanager +def graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the NPU graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current NPU stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. 
+ """ + graph_capture_context = GraphCaptureContext( + torch.npu.Stream(device=device)) + stream = graph_capture_context.stream + + # we use nullcontext now + maybe_ca_context = nullcontext() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.npu.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.npu.stream(stream), maybe_ca_context: + yield graph_capture_context + + +# Wrapper for ModelRunnerOutput to support overlapped execution. +class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput): + + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampled_token_ids: torch.Tensor, + invalid_req_indices: list[int], + async_output_copy_stream: torch.npu.Stream, + ): + self._model_runner_output = model_runner_output + self._invalid_req_indices = invalid_req_indices + + # Event on the copy stream so we can synchronize the non-blocking copy. + self._async_copy_ready_event = torch.npu.Event() + + # Keep a reference to the device tensor to avoid it being + # deallocated until we finish copying it to the host. + self._sampled_token_ids = sampled_token_ids + + # Initiate the copy on a separate stream, but do not synchronize it. + default_stream = torch.npu.current_stream() + with torch.npu.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self._sampled_token_ids_cpu = self._sampled_token_ids.to( + 'cpu', non_blocking=True) + self._async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + + This function blocks until the copy is finished. + """ + self._async_copy_ready_event.synchronize() + + # Release the device tensor once the copy has completed + del self._sampled_token_ids + + valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self._model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + + +class NPUModelRunner(LoRAModelRunnerMixin): + + def __init__(self, vllm_config: VllmConfig, device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.load_config = vllm_config.load_config + self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config + self.pin_memory = is_pin_memory_available() + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.block_size = vllm_config.cache_config.block_size + self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, + self.block_size) + self.max_num_tokens = self.scheduler_config.max_num_batched_tokens + decode_max_num_seqs = getattr(self.scheduler_config, + 'decode_max_num_seqs', 0) + self.max_num_reqs = max(self.scheduler_config.max_num_seqs, + decode_max_num_seqs) + self.dp_size = vllm_config.parallel_config.data_parallel_size + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.device = device + if envs_ascend.vllm_npu_ENABLE_PREFETCH_MLP: + self.prefetch_stream = torch.npu.Stream(device=device) + else: + self.prefetch_stream = None + self.dtype = self.model_config.dtype + if envs_ascend.vllm_npu_ENABLE_TOPK_TOPP_OPTIMIZATION: + # TODO: drop the env config to use ascend sampler by 
default + from vllm_npu.sample.sampler import AscendSampler + + self.sampler = AscendSampler() + else: + from vllm.v1.sample.sampler import Sampler + + self.sampler = Sampler() + self.reorder_batch_threshold: Optional[int] = None + + # Lazy initialization, these will be set after __init__ + self.kv_caches: List[torch.Tensor] = [] + self.attn_groups: list[list[AttentionGroup]] = [] + self.encoder_cache: Dict[str, torch.Tensor] = {} + self.attn_mask = None + self.attn_state = None + self.requests: Dict[str, CachedRequestState] = {} + self.intermediate_tensors: Optional[IntermediateTensors] = None + self.runner_only_attn_layers: set[str] = set() + + self.ascend_config = get_ascend_config() + if self.ascend_config.ascend_scheduler_config.enabled: + self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled + else: + self.chunked_prefill_enabled = True + self.weight_prefetch_method = WeightPrefetchMethod( + self.ascend_config.weight_prefetch_config) + + if self.cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + # use_hybrid_blocks: if hybrid blocks is used. + self.use_hybrid_blocks: bool = False + self.need_accepted_tokens: bool = False + + self.is_multimodal_model = self.model_config.is_multimodal_model + self.is_pooling_model = self.model_config.pooler_config is not None + if self.is_multimodal_model: + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.model_config.get_hidden_size()), + dtype=self.dtype, + device=self.device) + # Set up Attention + self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, + "index_topk") + self.attn_backend = get_attn_backend(0, + self.dtype, + None, + self.block_size, + use_mla=self.model_config.use_mla, + use_sparse=self.use_sparse) + self.attn_mask_builder = AttentionMaskBuilder( + self.scheduler_config.max_num_batched_tokens, self.dtype, + self.device) + + # Set up speculative decoding. + self.spec_attn_mask = None + self.drafter: Optional[Union[NgramProposer, EagleProposer, + MtpProposer]] = None + self.actual_seq_lengths_q: list[int] = [] + self.decode_token_per_req = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + assert spec_token_num > 0 + self.decode_token_per_req = 1 + spec_token_num + self.spec_attn_mask = torch.triu(torch.ones(2048, + 2048, + dtype=torch.bool), + diagonal=1).to(self.device) + if get_pp_group().is_last_rank: + self.drafter = get_spec_decode_method( + self.speculative_config.method, self.vllm_config, + self.device, self) + self.rejection_sampler = AscendRejectionSampler() + self.actual_seq_lengths_q = list( + range(self.decode_token_per_req, self.max_num_tokens + 1, + self.decode_token_per_req)) + + # kv role + self.is_kv_producer = False + self.is_kv_consumer = False + if vllm_config.kv_transfer_config is not None: + self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer + self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer + + self._may_pad_kv_consumer_num_seq() + + # Persistent batch. 
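+        # The buffers below are allocated once for the largest possible batch
+        # (max_num_tokens tokens / max_num_reqs requests) and are re-filled in
+        # place on every step, so no per-step allocation is required.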
+ self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + self.query_start_loc = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device=self.device) + self.seq_lens = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) + self.slot_mapping = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + + if self.vllm_config.model_config.use_mla and \ + self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos = torch.ones(self.max_num_reqs * + self.decode_token_per_req, + 1, + 1, + rope_dim, + dtype=self.dtype, + device=self.device) + self.sin = torch.zeros(self.max_num_reqs * + self.decode_token_per_req, + 1, + 1, + rope_dim, + dtype=self.dtype, + device=self.device) + else: + self.cos = None + self.sin = None + + self.uses_mrope = self.model_config.uses_mrope + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + # NOTE: `mrope_positions` is implemented with one additional dummy + # position on purpose to make it non-contiguous so that it can work + # with torch compile. + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. + # See page 5 of https://arxiv.org/abs/2409.12191 + self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), + dtype=torch.int64, + device=self.device) + self.mrope_positions_cpu = torch.zeros( + (3, self.max_num_tokens + 1), + dtype=torch.int64, + device="cpu", + pin_memory=True) + self.mrope_positions_np = self.mrope_positions_cpu.numpy() + + # OPTIMIZATION: Cache the tensors rather than creating them every step. + self.arange_np: npt.NDArray[np.int32] = np.arange(max( + self.max_num_reqs + 1, self.model_config.max_model_len, + self.max_num_tokens), + dtype=np.int32) + # NOTE(woosuk): These tensors are "stateless", i.e., they are literally + # a faster version of creating a new tensor every time. Thus, we should + # not make any assumptions about the values in these tensors. + self.input_ids_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.positions_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=True) + self.positions_np = self.positions_cpu.numpy() + + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + self.seq_lens_cpu = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.seq_lens_np = self.seq_lens_cpu.numpy() + + self.use_aclgraph = self._use_aclgraph() + self.aclgraph_batch_sizes = list( + reversed(self.compilation_config.cudagraph_capture_sizes)) + + self.uniform_decode_query_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens + # aclgraph dispatcher for runtime aclgraph dispatching. 
+ self.aclgraph_dispatcher = CudagraphDispatcher(self.vllm_config) + # Cached outputs. + self._draft_token_ids: Optional[Union[list[list[int]], + torch.Tensor]] = None + + # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True + self.in_profile_run = False + + self._init_mc2_tokens_capacity() + if is_moe_model(vllm_config): + self.reserved_mc2_mask = torch.zeros( + self.mc2_tokens_capacity, + dtype=torch.bool, + device=self.device, + ) + else: + self.reserved_mc2_mask = None + self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path + if self.dynamic_eplb: + EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb) + EPLBParamUtils.check_expert_map_record_path( + self.ascend_config.expert_map_record_path) + self.is_eplb_warmuped = False + self.policy_type = self.ascend_config.eplb_policy_type + self.eplb_loader = D2DExpertWeightLoader() + self.manager = Manager() + self.shared_dict = self.manager.dict({ + "expert_map": None, + "moe_load": None, + "expert_maps": None + }) + self.eplb_process = EplbProcess(shared_dict=self.shared_dict, + policy_type=self.policy_type, + enable_d2d=True) + self.process = self.eplb_process._launch_process() + ascend_config = get_ascend_config() + self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader, + self.eplb_process, self.process) + + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.async_output_copy_stream = torch.npu.Stream() if \ + self.use_async_scheduling else None + # Input Batch + # NOTE(Chen): Ideally, we should initialize the input batch inside + # `initialize_kv_cache` based on the kv cache config. However, as in + # https://github.com/vllm-project/vllm/pull/18298, due to some unknown + # reasons, we have to initialize the input batch before `load_model`, + # quantization + weight offloading will fail otherwise. As a temporary + # solution, we initialize the input batch here, and re-initialize it + # in `initialize_kv_cache` if the block_sizes here is different from + # the block_sizes in the kv cache config. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.model_config.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=[self.block_size], + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, self.device, self.pin_memory, + self.is_pooling_model, + self.vllm_config.model_config.logits_processors), + is_pooling_model=self.is_pooling_model, + kernel_block_sizes=[[self.vllm_config.cache_config.block_size]], + ) + self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + + def _may_pad_kv_consumer_num_seq(self): + # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario, + # we may want to pad self.max_num_seqs in kv_consumer nodes to avoid + # exceeding a sequence length limit (16 tokens) in npu_fused_infer_attention_score operation + pass + + def _init_mc2_tokens_capacity(self): + # NOTE: To be clear, we need to make sure that during graph capture, the number of + # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes, + # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512). 
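+        # For example (hypothetical numbers): max_num_tokens=500 and
+        # tp_size=8 reserve ceil(500 / 8) * 8 = 504 tokens of capacity.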
+        if self.compilation_config.cudagraph_capture_sizes:
+            max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
+        else:
+            # NOTE: To save memory, we cap the max number of tokens to 512.
+            max_num_tokens = min(
+                self.max_num_reqs * self.uniform_decode_query_len, 512)
+        tp_size = self.parallel_config.tensor_parallel_size
+        # Use integer arithmetic for ceiling division.
+        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+
+        # Only relevant for multimodal models
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
+            self.model_config)
+        if self.supports_mm_inputs:
+            self.is_mm_embed = self._make_buffer(self.max_num_tokens,
+                                                 dtype=torch.bool)
+
+    def _make_buffer(self,
+                     *size: Union[int, torch.SymInt],
+                     dtype: torch.dtype,
+                     numpy: bool = True) -> CpuGpuBuffer:
+        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
+        # if a bfloat16 buffer is needed without a corresponding numpy array,
+        # don't bother instantiating the numpy array.
+        return CpuGpuBuffer(*size,
+                            dtype=dtype,
+                            device=self.device,
+                            pin_memory=self.pin_memory,
+                            with_numpy=numpy)
+
+    def _update_states_after_model_execute(
+            self, output_token_ids: torch.Tensor) -> None:
+        """Update the cached states after model execution.
+
+        This is used for MTP/EAGLE with hybrid models, since in linear
+        attention only the last token's state is kept. In MTP/EAGLE, the
+        states for draft tokens are kept until we decide how many tokens are
+        accepted for each sequence, and a shift is applied during the next
+        iteration based on the number of accepted tokens.
+        """
+        if not self.model_config.is_hybrid or not self.speculative_config:
+            return
+
+        # Find the number of accepted tokens for each sequence.
+        num_accepted_tokens = (torch.cat(
+            [
+                output_token_ids,
+                torch.full((output_token_ids.size(0), 1),
+                           -1,
+                           device=output_token_ids.device),
+            ],
+            dim=1) == -1).int().argmax(-1).cpu().numpy()
+        for i, num_tokens in enumerate(num_accepted_tokens):
+            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
+
+    def _use_aclgraph(self) -> bool:
+        return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
+
+    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
+        # Remove finished requests from the cached states.
+        for req_id in scheduler_output.finished_req_ids:
+            self.requests.pop(req_id, None)
+
+        # Remove the finished requests from the persistent batch.
+        # NOTE(woosuk): There could be an edge case where finished_req_ids and
+        # scheduled_req_ids overlap. This happens when a request is aborted and
+        # then resubmitted with the same ID. In this case, we treat them as two
+        # distinct requests - clearing the cached states for the first request
+        # and handling the second as a new request.
+        for req_id in scheduler_output.finished_req_ids:
+            self.input_batch.remove_request(req_id)
+        for mm_hash in scheduler_output.free_encoder_mm_hashes:
+            self.encoder_cache.pop(mm_hash, None)
+        # Remove the unscheduled requests from the persistent batch.
+        # NOTE(woosuk): The unscheduled requests are either preempted requests
+        # or running requests that are not scheduled in this step. We remove
+        # them from the persistent batch but keep their cached states since
+        # they will be scheduled again sometime in the future.
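+        # E.g. (hypothetical ids): cached requests {A, B, C} with scheduled
+        # requests {A, C} leave {B} unscheduled; B is removed from the
+        # persistent batch below but keeps its cached state in self.requests.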
+ scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + req_ids_to_add: list[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if sampling_params and \ + sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if pooling_params: + assert (task := pooling_params.task) is not None, ( + "You did not set `task` in the API") + model = cast(VllmModelForPooling, self.get_model()) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + backward_kwargs = {} + backward_kwargs["mm_features"] = new_req_data.mm_features + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + **backward_kwargs, + ) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._init_mrope_positions(self.requests[req_id]) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = (num_computed_tokens + len(new_token_ids) - + req_state.num_tokens) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) + + # Update the block IDs. + if not resumed_from_preemption: + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) + else: + assert new_block_ids is not None + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. 
+ req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens_no_spec[ + req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + if spec_token_ids: + num_spec_tokens = len(spec_token_ids) + start_index = self.input_batch.num_tokens_no_spec[req_index] + end_token_index = start_index + num_spec_tokens + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec tokens. + self.input_batch.num_tokens[req_index] += num_spec_tokens + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. 
+        self.input_batch.refresh_metadata()
+
+    def _init_mrope_positions(self, req_state: CachedRequestState):
+        image_grid_thw = []
+        video_grid_thw = []
+        second_per_grid_ts = []
+        audio_feature_lengths = []
+        use_audio_in_video = False
+        assert req_state.mm_features is not None
+        for mm_feature in req_state.mm_features:
+            mm_item = mm_feature.data
+            if mm_item is None:
+                continue
+            mm_input = mm_item.get_data()
+            if (t := mm_input.get("image_grid_thw")) is not None:
+                image_grid_thw.append(t.tolist())
+            if (t := mm_input.get("video_grid_thw")) is not None:
+                video_grid_thw.append(t.tolist())
+            if (t := mm_input.get("second_per_grid_ts")) is not None:
+                second_per_grid_ts.append(t)
+            if (t := mm_input.get("audio_feature_lengths")) is not None:
+                audio_feature_lengths.append(t)
+            if mm_input.get("use_audio_in_video") is True:
+                use_audio_in_video = True
+
+        req_state.mrope_positions, req_state.mrope_position_delta = \
+            MRotaryEmbedding.get_input_positions_tensor(
+                req_state.prompt_token_ids,
+                hf_config=self.model_config.hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                second_per_grid_ts=second_per_grid_ts,
+                audio_feature_lengths=audio_feature_lengths,
+                use_audio_in_video=use_audio_in_video,
+            )
+
+    def _sync_metadata_across_dp(
+        self, num_tokens: int, with_prefill: bool, enable_dbo: bool
+    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        # TODO: In vLLM, the only thing that needs to be synced is num_tokens,
+        # but in our case, we still need to sync the other two flags as well.
+        # So we need to include them in the all_reduce operation, and
+        # moreover, we CANNOT skip it even when running in eager mode, which
+        # hurts performance.
+        # FIXME: Restore the `or self.vllm_config.model_config.enforce_eager`
+        # here immediately once the other two flags are no longer needed.
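+        # Illustrative example (hypothetical values) with dp_size=2: rank 0
+        # contributes [n0, 0] and rank 1 contributes [0, n1] as token counts,
+        # plus the two flags, so after the all_reduce below every rank sees
+        # [n0, n1, sum(with_prefill), sum(not enable_dbo)] and pads to
+        # max(n0, n1).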
+ if self.dp_size == 1: + return num_tokens, None, with_prefill, enable_dbo + + # Sync num_tokens, with_prefill, enable_dbo across dp ranks + num_tokens_tensor = torch.tensor([ + num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size) + ], + dtype=torch.int32, + device="npu") + + flags_tensor = torch.tensor( + [int(with_prefill), int(not enable_dbo)], + dtype=torch.int32, + device="npu") + + packed_tensor = torch.cat([num_tokens_tensor, flags_tensor]) + + dist.all_reduce(packed_tensor, group=get_dp_group().device_group) + + # Unpack the results + num_tokens_across_dp = packed_tensor[:-2] + synced_flags = packed_tensor[-2:] + + max_tokens_across_dp = torch.max(num_tokens_across_dp).item() + global_with_prefill = bool(synced_flags[0]) + global_enable_dbo = not bool(synced_flags[1]) + + # Create a tensor for num_tokens_after_padding + num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * + self.dp_size, + device="cpu", + dtype=torch.int32) + + return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo + + def _check_dbo_is_valid(self, query_lens: torch.Tensor, + attn_state: AscendAttentionState, + num_tokens: int) -> bool: + # do the checks for dp + dbo + if attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + return False + # considering the case that one dp rank may enable dbo while others may not + if not self.vllm_config.model_config.use_mla or not envs_ascend.vllm_npu_ENABLE_DBO: + return False + # TODO: remove it if token-level microbatch is enabled + [token_index, + seq_index] = compute_split_seq_index(query_lens, attn_state, + num_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + query_lens) or num_tokens < 256: + return False + return True + + def get_model(self) -> nn.Module: + # get raw model out of the aclgraph wrapper. + if isinstance(self.model, ACLGraphWrapper): + return self.model.unwrap() + return self.model + + def get_supported_generation_tasks(self) -> "list[GenerationTask]": + model = self.get_model() + supported_tasks = list[GenerationTask]() + + if is_text_generation_model(model): + supported_tasks.append("generate") + + if supports_transcription(model): + if model.supports_transcription_only: + return ["transcription"] + + supported_tasks.append("transcription") + + return supported_tasks + + def get_supported_tasks(self) -> "tuple[SupportedTask, ...]": + tasks = list[SupportedTask]() + + if self.model_config.runner_type == "generate": + tasks.extend(self.get_supported_generation_tasks()) + if self.model_config.runner_type == "pooling": + tasks.extend(self.get_supported_pooling_tasks()) + + return tuple(tasks) + + def _make_attention_mask(self, seq_lens, position, + attn_state) -> torch.Tensor: + # Pooling situation. + if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": + return self.attn_mask_builder.get_pooling_mask(self.device) + # Chunk Prefill situation. + elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse: + return self.attn_mask_builder.get_splitfuse_attn_mask() + + # Prefill without cache situation. + elif attn_state == AscendAttentionState.PrefillNoCache: + max_seq_len = max(seq_lens.max().item(), 0) + return self.attn_mask_builder.get_attn_mask( + max_seq_len, self.dtype, self.device) + # Prefill with cache hit. 
+ elif attn_state == AscendAttentionState.PrefillCacheHit: + return self.attn_mask_builder.get_attn_mask( + 128, self.dtype, self.device) + # Decode-only situation. + else: + return None + + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + mrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, + num_prompt_tokens - num_computed_tokens) + completion_part_len = max( + 0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] + + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + MRotaryEmbedding.get_next_input_positions_tensor( + out=self.mrope_positions_np, + out_offset=dst_start, + mrope_position_delta=req.mrope_position_delta, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) + + mrope_pos_ptr += completion_part_len + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output) + encoder_outputs = [] + + for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=True, + ): + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, each of shape + # (feature_size, hidden_size) in case the feature size is dynamic + # depending on the input multimodal items. + curr_group_outputs = self.model.get_multimodal_embeddings( + **mm_kwargs_group) + + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=num_items, + ) + + for output in curr_group_outputs: + encoder_outputs.append(output) + + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( + output, + is_embed=pos_info.is_embed, + ) + + def _batch_mm_kwargs_from_scheduler( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + """Batch multimodal kwargs from scheduled encoder inputs. + + Args: + scheduler_output: The scheduler output containing scheduled encoder + inputs. 
+ + Returns: + A tuple of (mm_kwargs, req_ids_pos) where: + - mm_kwargs: List of multimodal kwargs items to be batched + - mm_hashes_pos: List of (mm_hash, position_info) tuples + """ + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return [], [] + # Batch the multi-modal inputs. + mm_kwargs = list[MultiModalKwargsItem]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + assert req_state.mm_features is not None + for mm_input_id in encoder_input_ids: + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) + + return mm_kwargs, mm_hashes_pos + + def _gather_mm_embeddings( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[torch.Tensor], torch.Tensor]: + + def _iter_mm_features(req_state: CachedRequestState): + assert req_state.mm_features is not None + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position + yield mm_feature.identifier, pos_info, getattr( + pos_info, "is_embed", None) + + mm_embeds: list[torch.Tensor] = [] + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + is_mm_embed = self.is_mm_embed.cpu + is_mm_embed[:total_num_scheduled_tokens] = False + + req_start_idx = 0 + + for req_id in self.input_batch.req_ids: + mm_embeds_req: list[torch.Tensor] = [] + + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + + for mm_hash, pos_info, is_embed in _iter_mm_features(req_state): + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length + + if start_pos >= num_computed_tokens + num_scheduled_tokens: + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx + + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None, \ + f"Encoder cache miss for {mm_hash}." + + if is_embed is not None: + is_embed = is_embed[start_idx:end_idx] + + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ + = True if is_embed is None else is_embed + + mm_embeds_item = gather_mm_placeholders( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + mm_embeds_req.append(mm_embeds_item) + + mm_embeds.extend(mm_embeds_req) + req_start_idx += num_scheduled_tokens + + is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) + + return mm_embeds, is_mm_embed + + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. 
[2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + NPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], + non_blocking=True) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the NPU from prev_sampled_token_ids. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + flattened_indices = [] + prev_common_req_indices = [] + indices_match = True + max_flattened_index = -1 + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_index = cu_num_tokens[cur_index].item() - 1 + flattened_indices.append(flattened_index) + indices_match &= (prev_index == flattened_index) + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(flattened_indices) + if num_commmon_tokens < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the NPU first. + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], + non_blocking=True) + if num_commmon_tokens == 0: + # No requests in common with the previous iteration + # So input_ids_cpu will have all the input ids. + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 so + # we can copy directly using a single slice. + self.input_ids[:num_commmon_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, + 0], + non_blocking=True) + return + # Upload the index tensors asynchronously + # so the scatter can be non-blocking. + input_ids_index_tensor = torch.tensor(flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to( + self.device, + non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) + self.input_ids.scatter_(dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor, 0]) + + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. 
+
+        Args:
+            scheduler_output: The scheduler output.
+        """
+        # Attention-free models have zero kv_cache_groups; however, models
+        # like Mamba are also attention-free but use the kv_cache to keep
+        # their internal state. This is why we check the number of kv_cache
+        # groups instead of relying solely on
+        # self.model_config.is_attention_free.
+        if len(self.kv_cache_config.kv_cache_groups) == 0:
+            return
+
+        if self.reorder_batch_threshold is not None:
+            reorder_batch_to_split_decodes_and_prefills(
+                self.input_batch,
+                scheduler_output,
+                decode_threshold=self.reorder_batch_threshold)
+
+    def _prepare_inputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
+               int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
+               Optional[torch.Tensor], Optional[torch.Tensor], int]:
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        assert total_num_scheduled_tokens > 0
+        num_reqs = self.input_batch.num_reqs
+        assert num_reqs > 0
+
+        # OPTIMIZATION: Start copying the block table first.
+        # This way, we can overlap the copy with the following CPU operations.
+        self.input_batch.block_table.commit_block_table(num_reqs)
+
+        # Get the number of scheduled tokens for each request.
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+        max_num_scheduled_tokens = num_scheduled_tokens.max()
+        num_valid_tokens = np.array([
+            num_tokens -
+            len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+            for num_tokens, i in zip(tokens, req_ids)
+        ],
+                                    dtype=np.int32)
+
+        if (self.use_aclgraph and total_num_scheduled_tokens
+                <= self.aclgraph_batch_sizes[-1]):
+            # Add padding to the batch size.
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                total_num_scheduled_tokens)
+        elif self.use_aclgraph and enable_sp(self.vllm_config):
+            # When using aclgraph, if total_num_scheduled_tokens exceeds the
+            # maximum graph size, the model falls back to running its FX graph
+            # in eager mode. In this case, when sequence parallelism is
+            # enabled, we need to pad the tokens to align with tp_size,
+            # because pad_size cannot be captured by the FX graph.
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            num_input_tokens = math.ceil(
+                total_num_scheduled_tokens / tp_size) * tp_size
+        else:
+            # Eager mode.
+            num_input_tokens = total_num_scheduled_tokens
+
+        # Get the attention state.
+        attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
+                                            num_valid_tokens)
+        self.attn_state = attn_state  # type: ignore
+
+        # Determine if it's a splitfuse batch
+        with_prefill = attn_state not in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
+
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
+
+        # Get info across DP ranks.
+ # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, + # Otherwise, it's just max_tokens_across_dp_cpu + (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, + enable_dbo) = self._sync_metadata_across_dp(num_input_tokens, + with_prefill, enable_dbo) + + # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens + # We should consider removing maybe_padded_num_tokens later + num_input_tokens = maybe_padded_num_tokens + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # Prepare input_ids. + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # Prepare some information for building Attention-Metadata + # Compute and commit slot mapping + self.input_batch.block_table.compute_slot_mapping( + req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping( + total_num_scheduled_tokens) + + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc[:num_reqs + 1].copy_( + self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.query_start_loc[num_reqs + 1:].fill_(-1) + self.seq_lens[num_reqs:].fill_(0) + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + + # Copy the tensors to the NPU. 
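+        # _prepare_input_ids also covers the async-scheduling case where the
+        # previous step's sampled tokens still live on the device and are
+        # scattered into input_ids instead of being copied from the CPU.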
+ self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() + self.positions[:num_input_tokens].copy_( + self.positions_cpu[:num_input_tokens], non_blocking=True) + + # Make Attention metadata + positions_cpu = self.positions_cpu[:num_input_tokens] + positions = self.positions[:num_input_tokens] + seq_lens_cpu = self.seq_lens_cpu[:num_reqs] + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, + num_valid_tokens) + self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, + position=positions_cpu, + attn_state=attn_state) + self.attn_state = attn_state # type: ignore + + self.with_prefill = with_prefill + self.num_tokens_across_dp = num_tokens_across_dp + self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) + attn_metadata: dict[str, Any] = {} + + # _prepare_inputs may reorder the batch, so we must gather + # multi-modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings( + scheduler_output) + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[:total_num_scheduled_tokens] + model_type = self.vllm_config.model_config.hf_config.model_type + if model_type == "qwen2_5_vl" or model_type == "qwen3_vl_moe": + inputs_embeds = self.model.get_input_embeddings( + input_ids, + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + else: + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:total_num_scheduled_tokens].copy_( + inputs_embeds) + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the ACL graph. + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + positions = self.positions[:num_input_tokens] + input_ids, positions = self._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + maybe_padded_num_tokens) + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + assert intermediate_tensors is not None + assert self.intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + self.intermediate_tensors[k][:num_input_tokens].copy_( + v[:num_input_tokens], non_blocking=True) + intermediate_tensors = IntermediateTensors({ + k: v[:num_input_tokens] + for k, v in self.intermediate_tensors.items() + }) + + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. 
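+            # E.g. (hypothetical): cu_num_tokens = [2, 7, 10] yields
+            # logits_indices = [1, 6, 9], i.e. the index of the last
+            # scheduled token of each request.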
+ spec_decode_metadata = None + logits_indices = torch.from_numpy(cu_num_tokens - 1).to( + self.device, non_blocking=True) + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() + + # Used in the below loop. + # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + spec_decode_common_attn_metadata = None + if use_spec_decode and self.need_accepted_tokens: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens, ), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor() + slot_mapping = blk_table.slot_mapping_cpu[: + total_num_scheduled_tokens] + self.slot_mapping[:total_num_scheduled_tokens].copy_( + slot_mapping[:total_num_scheduled_tokens], + non_blocking=True, + ) + self.slot_mapping[total_num_scheduled_tokens:].fill_(0) + + # Make AscendCommonAttentionMetadata + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + num_input_tokens=num_input_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + # TODO: change this to the right block table for linear attn + block_table_tensor=blk_table_tensor[:num_reqs], + slot_mapping=self.slot_mapping, + num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + enable_dbo_across_dp=enable_dbo, + is_only_prefill=bool(np.all(num_valid_tokens != 1)), + max_query_len=max_num_scheduled_tokens, + graph_pad_size=self.graph_pad_size, + decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, + ) + + if self.speculative_config and \ + spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_group in self.attn_groups[kv_cache_group_id]: + 
common_prefix_len = 0 + extra_attn_metadata_args = {} + builder = attn_group.get_metadata_builder() + if isinstance(builder, GDNAttentionMetadataBuilder + ) or self.model_config.runner_type == "pooling": + if use_spec_decode: + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens. + gpu[:num_reqs], + num_decode_draft_tokens_cpu=self.num_draft_tokens. + gpu[:num_reqs], + ) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) + else: + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + model=self.get_model(), + **extra_attn_metadata_args) + + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + if lmhead_tp_enable(): + max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs + logits_indices = nn.functional.pad( + logits_indices, + (0, max_num_reqs_across_dp - logits_indices.shape[0])) + + return (attn_metadata, positions, num_scheduled_tokens, + num_input_tokens, num_tokens_across_dp, + maybe_padded_num_tokens, logits_indices, spec_decode_metadata, + input_ids, inputs_embeds, intermediate_tensors, + max_num_scheduled_tokens) + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + maybe_padded_num_tokens, + input_ids, positions, + intermediate_tensors, + inputs_embeds): + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: + # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + maybe_padded_num_tokens, + self.speculative_config) + else: + update_attn_params(self.update_stream, forward_context, + maybe_padded_num_tokens, + self.vllm_config.kv_transfer_config) + + if get_forward_context().sp_enabled: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + hidden_states = hidden_states[:-pad_size, :] + return hidden_states + + def _build_attn_state(self, num_reqs, num_scheduled_tokens, + num_valid_tokens): + ascend_config = get_ascend_config() + if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + attn_state = AscendAttentionState.PrefillNoCache + # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. + elif np.all(num_scheduled_tokens == 1): + attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and self.speculative_config.method == 'deepseek_mtp': + # SpecDecoding now supports seq_len=1 and seq_len=2 + # In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1 + attn_state = AscendAttentionState.SpecDecoding + # Speculative decoding. 
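+        # NOTE (illustrative, not part of this change; assumes num_valid_tokens
+        # counts the scheduled tokens minus any draft tokens): a decode step
+        # that also carries one draft token per request would look like
+        #
+        #     num_scheduled_tokens = [2, 2]
+        #     num_valid_tokens     = [1, 1]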
+ elif np.all(num_valid_tokens == 1): + if self.speculative_config and self.speculative_config.method == 'deepseek_mtp': + attn_state = AscendAttentionState.SpecDecoding + else: + attn_state = AscendAttentionState.ChunkedPrefill + # splitfuse + elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled: + attn_state = AscendAttentionState.ChunkedPrefill + else: + attn_state = AscendAttentionState.PrefillCacheHit + return attn_state + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + self.graph_pad_size = -1 + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + maybe_padded_num_tokens): + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + return input_ids, positions + + def _calc_spec_decode_metadata( + self, + num_draft_tokens: np.ndarray, + cu_num_scheduled_tokens: np.ndarray, + ) -> SpecDecodeMetadata: + # Inputs: + # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] + # num_draft_tokens: [ 3, 0, 2, 0, 1] + # Outputs: + # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] + # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, + # 206, 207, 208] + # target_logits_indices: [ 0, 1, 2, 5, 6, 9] + # bonus_logits_indices: [ 3, 4, 7, 8, 10] + + # Compute the logits indices. + # [4, 1, 3, 1, 2] + num_sampled_tokens = num_draft_tokens + 1 + # Step 1. [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) + total_num_sampled_tokens = cu_num_sampled_tokens[-1] + # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, + num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets + # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_indices = np.repeat( + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + logits_indices += arange + + # Compute the bonus logits indices. + bonus_logits_indices = cu_num_sampled_tokens - 1 + + # Compute the draft logits indices. + # [3, 3, 5, 5, 6] + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + total_num_draft_tokens = cu_num_draft_tokens[-1] + # [0, 0, 0, 3, 3, 5] + cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, + num_draft_tokens) + # [0, 1, 2, 0, 1, 0] + arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets + # [0, 0, 0, 5, 5, 9] + target_logits_indices = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + # [0, 1, 2, 5, 6, 9] + target_logits_indices += arange + + # TODO: Optimize the CPU -> NPU copy. + cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( + self.device, non_blocking=True) + logits_indices = torch.from_numpy(logits_indices).to(self.device, + non_blocking=True) + target_logits_indices = torch.from_numpy(target_logits_indices).to( + self.device, non_blocking=True) + bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( + self.device, non_blocking=True) + + # Compute the draft token ids. 
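+        # NOTE (illustrative, not part of this change): the draft ids are
+        # recovered with a double gather: first gather the flattened input ids
+        # at `logits_indices`, then keep the entries one past each target
+        # index. With the example arrays above:
+        #
+        #     logits_indices            = [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        #     target_logits_indices + 1 = [1, 2, 3, 6, 7, 10]
+        #     -> gathered positions       [1, 2, 3, 105, 106, 208]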
+ # draft_token_indices: [ 1, 2, 3, 105, 106, 208] + draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = draft_token_ids[target_logits_indices + 1] + + metadata = SpecDecodeMetadata( + draft_token_ids=draft_token_ids, + num_draft_tokens=num_draft_tokens.tolist(), + cu_num_draft_tokens=cu_num_draft_tokens, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) + return metadata + + def apply_grammar_bitmask( + self, + scheduler_output: "SchedulerOutput", + logits: torch.Tensor, + ) -> torch.Tensor: + grammar_bitmask = scheduler_output.grammar_bitmask + + # We receive the structured output bitmask from the scheduler, + # compacted to contain bitmasks only for structured output requests. + # The order of the requests in the bitmask is not guaranteed to be the + # same as the order of the requests in the gpu runner's batch. We need + # to sort the bitmask to match the order of the requests used here. + + # Get the batch indices of the structured output requests. + # Keep track of the number of speculative tokens scheduled for every + # request in the batch, as the logit indices are offset by this amount. + struct_out_req_batch_indices: dict[str, int] = {} + cumulative_offset = 0 + seq = sorted(self.input_batch.req_id_to_index.items(), + key=lambda x: x[1]) + for req_id, batch_index in seq: + logit_index = batch_index + cumulative_offset + cumulative_offset += len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + if req_id in scheduler_output.structured_output_request_ids: + struct_out_req_batch_indices[req_id] = logit_index + + out_indices = [] + + # Reorder the bitmask to match the order of the requests in the batch. + sorted_bitmask = np.zeros_like(grammar_bitmask, + shape=(logits.shape[0], + grammar_bitmask.shape[1])) + cumulative_index = 0 + seq = sorted(scheduler_output.structured_output_request_ids.items(), + key=lambda x: x[1]) + for req_id, _ in seq: + logit_index = struct_out_req_batch_indices[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + for i in range(1 + num_spec_tokens): + sorted_bitmask[logit_index + i] = \ + grammar_bitmask[cumulative_index + i] + out_indices.append(logit_index + i) + cumulative_index += 1 + num_spec_tokens + grammar_bitmask = sorted_bitmask + + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = torch.from_numpy(grammar_bitmask) + + # NOTE: + # 1. XGrammar bitmask applying only supports CPU and GPU. + # 2. The logits and bitmask should be on the same device. + # 3. XGrammar logits on CPU only supports float32 dtype. + logits_dtype = logits.dtype + logits = logits.to("cpu").float() + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask, + indices=out_indices, + ) + return logits.to(self.device).to(logits_dtype) + + def propose_draft_token_ids( + self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + scheduler_output: "SchedulerOutput", + spec_decode_metadata: SpecDecodeMetadata, + positions: torch.Tensor, + num_scheduled_tokens: int, + hidden_states: torch.Tensor, + attn_metadata: dict[str, Any], + aux_hidden_states: torch.Tensor = None, + ) -> Optional[list[list[int]]]: + if not self.drafter: + # Speculative decoding is not enabled. 
+ draft_token_ids = None + else: + draft_token_ids = self.drafter.generate_token_ids( + valid_sampled_token_ids, sampling_metadata, scheduler_output, + spec_decode_metadata, positions, num_scheduled_tokens, + hidden_states, attn_metadata, aux_hidden_states) + return draft_token_ids + + def _pool( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + num_scheduled_tokens_np: np.ndarray, + finished_sending: Optional[set[str]] = None, + finished_recving: Optional[set[str]] = None, + kv_connector_output: Optional["KVConnectorOutput"] = None, + ) -> ModelRunnerOutput: + assert self.input_batch.num_reqs ==\ + len(self.input_batch.pooling_params), \ + "Either all or none of the requests in" \ + " a batch must be pooling request" + + hidden_states = hidden_states[:num_scheduled_tokens] + pooling_metadata = self.input_batch.pooling_metadata + pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), + device=hidden_states.device) + seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs] + + model = cast(VllmModelForPooling, self.model) + raw_pooler_output = model.pooler( + hidden_states=hidden_states, + pooling_metadata=pooling_metadata, + ) + raw_pooler_output = json_map_leaves( + lambda x: x.to("cpu", non_blocking=True), + raw_pooler_output, + ) + torch.npu.synchronize() + + pooler_output: list[Optional[torch.Tensor]] = [] + for raw_output, seq_len, prompt_len in zip( + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): + output = raw_output if seq_len == prompt_len else None + pooler_output.append(output) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + ) + + def _select_moe_comm_method(self, num_tokens: int, + with_prefill: bool) -> Optional[MoECommType]: + """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all + are designed for expert parallelism. + 2. If expert parallel is enabled, we need to consider the soc version and the + number of tokens. This is based on the observation that all-gather is more + efficient than all-to-all when running on A2. + + a. For A2, we choose from MC2 and all-gather. + + b. For A3, we choose from MC2 and all-to-all. + + In both cases, we use MC2 when the number of tokens is smaller than + a its capacity threshold. + + Args: + num_tokens (int): The number of tokens in the current batch. + + Raises: + ValueError: If the soc version is unsupported. + + Returns: + MoECommType: The selected MoE communication method. 
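+
+        Example (illustrative only; actual thresholds depend on the deployment):
+            On A2 with expert parallelism and a large enough DP world size,
+            batches within the MC2 token capacity use MC2, while larger
+            batches fall back to all-gather (all-to-all for w4a8_dynamic
+            weights); on A3 the large-batch fallback is all-to-all.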
+ """ + if not is_moe_model(self.vllm_config): + return None + + soc_version = get_ascend_soc_version() + quant_type = getattr(self.vllm_config.model_config.hf_config, + 'moe_quantize', None) + model_type = self.vllm_config.model_config.hf_config.model_type + + if not self.parallel_config.enable_expert_parallel: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendSocVersion.A2}: + if (num_tokens <= self.mc2_tokens_capacity + and self.parallel_config.world_size_across_dp >= 16): + moe_comm_type = MoECommType.MC2 + else: + # Currently, w4a8_dynamic does not support allgatherep + if quant_type == "w4a8_dynamic": + moe_comm_type = MoECommType.ALLTOALL + else: + moe_comm_type = MoECommType.ALLGATHER + + elif soc_version in {AscendSocVersion.A3}: + moe_comm_type = (MoECommType.MC2 + if num_tokens <= self.mc2_tokens_capacity else + MoECommType.ALLTOALL) + else: + raise ValueError(f"Unsupported soc_version: {soc_version}") + + if moe_comm_type == MoECommType.ALLGATHER and with_prefill: + if enable_sp(): + moe_comm_type = MoECommType.ALLGATHER + else: + moe_comm_type = MoECommType.NAIVE_MULTICAST + + # PanguProMoE only supports allgather + if model_type == "PanguProMoE": + moe_comm_type = MoECommType.ALLGATHER + + if is_global_first_rank(): + logger.debug(f"num_tokens: {num_tokens}, " + f"moe_comm_type: {moe_comm_type}") + return moe_comm_type + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + with ProfileExecuteDuration().capture_async("prepare input"): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + logger.debug( + "skip this step for we receive the data from remote disaggregate prefill node" + ) + # Return empty ModelRunnerOuptut if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output) + + if self.dynamic_eplb: + self.eplb_updator.forward_before() + + (attn_metadata, positions, num_scheduled_tokens_np, + num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, + logits_indices, spec_decode_metadata, input_ids, inputs_embeds, + intermediate_tensors, + max_query_len) = (self._prepare_inputs(scheduler_output, + intermediate_tensors)) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + + moe_comm_type = self._select_moe_comm_method(num_input_tokens, + self.with_prefill) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + scheduler_output.total_num_scheduled_tokens + == self.input_batch.num_reqs * max_query_len) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + aclgraph_runtime_mode, batch_descriptor = \ + self.aclgraph_dispatcher.dispatch(batch_descriptor) + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + num_actual_tokens=scheduler_output. 
+ total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method): + self.maybe_setup_kv_connector(scheduler_output) + + hidden_states = self._generate_process_reqs_hidden_states( + attn_metadata, self.with_prefill, maybe_padded_num_tokens, + input_ids, positions, intermediate_tensors, inputs_embeds) + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer( + scheduler_output) + + aux_hidden_states = None + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, aux_hidden_states = hidden_states + + kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving) + finished_sending = None + finished_recving = None + with ProfileExecuteDuration().capture_async("post process"): + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + if not broadcast_pp_output: + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool( + hidden_states, + scheduler_output.total_num_scheduled_tokens, + num_scheduled_tokens_np, finished_sending, + finished_recving, kv_connector_output) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group( + ).broadcast_tensor_dict(model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + logits = self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + if lmhead_tp_enable() and logits is not None: + logits = logits[:self.input_batch.num_reqs] + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + if lmhead_tp_enable() and logits is not None: + logits = logits[:len(spec_decode_metadata.logits_indices)] + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[ + spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. 
Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[ + spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + if self.need_accepted_tokens: + self._update_states_after_model_execute(output_token_ids) + + discard_sampled_tokens_req_indices: list[int] = [] + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + # Ignore the sampled token. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + discard_sampled_tokens_req_indices.append(i) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:scheduler_output.total_num_scheduled_tokens], + scheduler_output, + ) + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + else: + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. 
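+        # NOTE (illustrative, not part of this change): with async scheduling
+        # the device->host sync is deferred, so the loop below appends a
+        # placeholder id per request and the real sampled ids are picked up
+        # from `prev_sampled_token_ids` when the next step prepares its
+        # inputs, e.g. (hypothetical values):
+        #
+        #     start = num_tokens_no_spec[req_idx]            # e.g. 17
+        #     token_ids_cpu[req_idx, start:start + 1] = -1   # patched later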
+ for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + if self.speculative_config: + self._draft_token_ids = self.propose_draft_token_ids( + valid_sampled_token_ids, + sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + scheduler_output.total_num_scheduled_tokens, + hidden_states, + attn_metadata, + aux_hidden_states, + ) + + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + extra_args = ({"kv_connector_output": kv_connector_output}) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + **extra_args, + ) + + durations = ProfileExecuteDuration().pop_captured_sync() + if durations: + dr_str = [ + f"[{tag}]:{duration:.2f}ms" + for tag, duration in durations.items() + ] + captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" + logger.info("Profile execute duration [%s]:%s", captured_name, + " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() + if not self.use_async_scheduling: + return model_runner_output + + return AsyncNPUModelRunnerOutput( + model_runner_output=model_runner_output, + sampled_token_ids=sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + if self._draft_token_ids is None: + return None + req_ids = self.input_batch.req_ids + if isinstance(self._draft_token_ids, torch.Tensor): + draft_token_ids = self._draft_token_ids.tolist() + else: + draft_token_ids = self._draft_token_ids + self._draft_token_ids = None + return DraftTokenIds(req_ids, draft_token_ids) + + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + with set_ascend_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfer(scheduler_output)) + # For the case of no forward caused by receiving remote kv, + # one round of dummy inference is necessary + # to prevent hang over the collective calls. + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving) + return output + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). 
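+        # NOTE (illustrative, not part of this change): the per-step connector
+        # lifecycle driven from execute_model() is roughly:
+        #
+        #     kv_connector.bind_connector_metadata(...)   # here, pre-forward
+        #     kv_connector.start_load_kv(forward_context)
+        #     ... model forward ...
+        #     maybe_wait_for_kv_save()                    # wait_for_save()
+        #     get_finished_kv_transfer(scheduler_output)  # get_finished(...)
+        #     kv_connector.clear_connector_metadata()     # end of the step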
+ if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfer( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None + + def _build_dummy_attn_metadata( + self, + with_prefill: bool, + num_reqs: int, + num_tokens: int, + max_query_len: int, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + ) -> Optional[dict[str, Any]]: + attn_metadata: Optional[dict[str, Any]] = None + + if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL: + assert with_prefill is False, \ + "Full decode graph only supports uniform batch now." + + attn_metadata = {} + + seq_lens = max_query_len + self.seq_lens_np[:num_reqs] = seq_lens + self.seq_lens_np[num_reqs:] = 0 + + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_table_tensor = self.input_batch.block_table[ + kv_cache_group_id].get_device_tensor() + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=torch.tensor( + [0] + self.actual_seq_lengths_q[:num_reqs], + device=self.device, + dtype=torch.int32), + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + 1], + seq_lens_cpu=self.seq_lens_cpu, + seq_lens=self.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + actual_seq_lengths_q=self.actual_seq_lengths_q, + block_table_tensor=block_table_tensor[:num_reqs], + slot_mapping=self.slot_mapping, + num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + max_query_len=max_query_len, + decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, + ) + attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and \ + self.speculative_config.method == "deepseek_mtp": + attn_state = AscendAttentionState.SpecDecoding + + for attn_group in self.attn_groups[kv_cache_group_id]: + builder = attn_group.get_metadata_builder() + attn_metadata_i = builder.build_for_graph_capture( + common_attn_metadata, attn_state, self.get_model()) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + return attn_metadata + + def _generate_dummy_run_hidden_states(self, with_prefill, + is_torchair_compile, input_ids, + positions, attn_metadata, num_tokens, + intermediate_tensors, inputs_embeds): + hidden_states = self.model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + forward_context = get_forward_context() + assert forward_context is not None + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ + not forward_context.capturing: + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + 
update_mla_attn_params(self.update_stream, forward_context, + num_tokens, self.speculative_config) + else: + update_attn_params(self.update_stream, forward_context, + num_tokens, + self.vllm_config.kv_transfer_config) + + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: + hidden_states, _ = hidden_states + else: + hidden_states = hidden_states + return hidden_states + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + is_torchair_compile: bool = False, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + uniform_decode: bool = False, + ) -> torch.Tensor: + # only support eager mode and piecewise graph now + assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in { + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL + } + + # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs. + # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size. + if self.use_aclgraph and enable_sp(self.vllm_config): + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + num_tokens = math.ceil(num_tokens / tp_size) * tp_size + + # Force dummy run on prefill stage when this node is deemed as kv producer. + if self.is_kv_producer and not self.is_kv_consumer: + with_prefill = True + + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill, + _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) + + moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.seperate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else \ + num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. 
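+        # NOTE (illustrative, not part of this change): e.g. a uniform-decode
+        # dummy run with num_tokens=10 and max_query_len=4 is split into
+        # [4, 4, 2]; a non-uniform run that ends up with num_reqs=3 becomes
+        # [3, 3, 4] (the remainder is folded into the last request), so the
+        # list always sums to num_tokens.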
+ assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.max_num_reqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + if with_prefill: + num_reqs = num_tokens + else: + num_reqs = (num_tokens + self.decode_token_per_req - + 1) // self.decode_token_per_req + num_reqs = min(num_reqs, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() + + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=num_tokens, + dtype=self.dtype, + device=self.device)) + intermediate_tensors = IntermediateTensors({ + k: v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + # filter out the valid batch descriptor + _ag_mode, batch_descriptor = \ + self.aclgraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) + if aclgraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for aclgraph capture + assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \ + aclgraph_runtime_mode == _ag_mode, ( + f"Aclgraph runtime mode mismatch at dummy_run. " + f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.") + else: + aclgraph_runtime_mode = _ag_mode + + # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup + # and not supported in ASCEND now. We could remove it in the future. 
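+            # NOTE (illustrative, not part of this change): dummy attention
+            # metadata is only materialized when capturing or replaying a FULL
+            # graph (or when force_attention is set); for eager and PIECEWISE
+            # dummy runs _build_dummy_attn_metadata() returns None.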
+ attn_metadata = self._build_dummy_attn_metadata( + False, + num_reqs=num_reqs, + num_tokens=num_tokens, + max_query_len=max_query_len, + aclgraph_runtime_mode=aclgraph_runtime_mode, + force_attention=force_attention, + ) + + need_dummy_logits = (not self.in_profile_run + and lmhead_tp_enable()) + + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, + dtype=torch.int32) + + def dummy_compute_logits(hidden_states): + if not need_dummy_logits: + return None + return self.model.compute_logits(hidden_states[dummy_indices]) + + def dummy_drafter_compute_logits(hidden_states): + if not need_dummy_logits or self.drafter is None: + return + if hasattr(self.drafter, "model") and hasattr( + self.drafter.model, "compute_logits"): + return self.drafter.model.compute_logits( + hidden_states[dummy_indices]) + + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + prefetch_stream=self.prefetch_stream, + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method): + hidden_states = self._generate_dummy_run_hidden_states( + with_prefill, is_torchair_compile, input_ids, positions, + attn_metadata, num_tokens, intermediate_tensors, + inputs_embeds) + dummy_compute_logits(hidden_states) + + if self.drafter: + self.drafter.dummy_run( + num_tokens=num_tokens, + with_prefill=with_prefill, + skip_attn=True, + num_reqs=num_reqs, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, + dummy_compute_logits=dummy_drafter_compute_logits) + if self.in_profile_run and self.dynamic_eplb: + self.model.clear_all_moe_loads() + if not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() + return hidden_states + + @contextmanager + def set_in_profile_run(self): + self.in_profile_run = True + try: + yield + finally: + self.in_profile_run = False + + def profile_run(self) -> None: + # Trigger compilation for general shape. + with self.set_in_profile_run(): + hidden_states = self._dummy_run(self.max_num_tokens, + with_prefill=True) + # MC2 will consume additional NPU memory. + # Therefore, we need to run the MC2 path once here to complete its initialization, + # allowing vLLM to correctly estimate the maximum memory required. + if self.max_num_tokens > self.mc2_tokens_capacity and \ + self._select_moe_comm_method( + self.mc2_tokens_capacity, + with_prefill=True) == MoECommType.MC2: + self._dummy_run(self.mc2_tokens_capacity, with_prefill=True) + + output = None + if get_pp_group().is_last_rank: + if self.is_pooling_model: + output = self._dummy_pooler_run(hidden_states) + else: + # For profile, have maximum num_reqs and that collectively have + # maximum num_tokens. 
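+                    # NOTE (illustrative, not part of this change): e.g. with
+                    # max_num_tokens=10 and max_num_reqs=4 the split is
+                    # [2, 2, 2, 4] and logit_indices = cumsum - 1 = [1, 3, 5, 9],
+                    # i.e. the last hidden state of every dummy request is fed
+                    # to compute_logits.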
+ min_tokens_per_req = self.max_num_tokens // self.max_num_reqs + num_scheduled_tokens_list = [min_tokens_per_req + ] * self.max_num_reqs + num_scheduled_tokens_list[ + -1] += self.max_num_tokens % self.max_num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + # TODO: need to rum a dummy sampler for generate task + hidden_states = hidden_states[logit_indices] + output = self.model.compute_logits(hidden_states) + + NPUPlatform.synchronize() + del hidden_states, output + self.encoder_cache.clear() + gc.collect() + + def _dummy_pooler_run_task( + self, + hidden_states: torch.Tensor, + task: PoolingTask, + ) -> PoolerOutput: + num_tokens = hidden_states.shape[0] + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + req_num_tokens = num_tokens // num_reqs + + dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), + dtype=torch.int32, + device=self.device) + + model = cast(VllmModelForPooling, self.get_model()) + dummy_pooling_params = PoolingParams(task=task) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(dummy_pooling_params) + + dummy_prompt_lens = torch.tensor( + num_scheduled_tokens_list, + device="cpu", + ) + dummy_metadata = PoolingMetadata( + prompt_lens=dummy_prompt_lens, + prompt_token_ids=dummy_token_ids, + pooling_params=[dummy_pooling_params] * num_reqs, + ) + + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, + device=hidden_states.device) + + try: + return model.pooler(hidden_states=hidden_states, + pooling_metadata=dummy_metadata) + except RuntimeError as e: + if 'out of memory' in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up pooler " + f"({task=}) with {num_reqs} dummy requests. 
Please try " + "lowering `max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine.") from e + else: + raise e + + @torch.inference_mode() + def _dummy_pooler_run( + self, + hidden_states: torch.Tensor, + ) -> PoolerOutput: + # Find the task that has the largest output for subsequent steps + output_size = dict[PoolingTask, float]() + for task in self.get_supported_pooling_tasks(): + # Run a full batch with each task to ensure none of them OOMs + output = self._dummy_pooler_run_task(hidden_states, task) + output_size[task] = sum(o.nbytes for o in output) + del output # Allow GC + + max_task = max(output_size.items(), key=lambda x: x[1])[0] + return self._dummy_pooler_run_task(hidden_states, max_task) + + def eplb_warmup(self): + if self.dynamic_eplb and not self.is_eplb_warmuped: + self.is_eplb_warmuped = True + self.eplb_adaptor = VllmEplbAdaptor(model=self.model) + self.eplb_loader.set_adator(self.eplb_adaptor) + self.eplb_updator.set_adaptor(self.eplb_adaptor) + self.eplb_updator.warm_up_eplb() + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + + with DeviceMemoryProfiler() as m: # noqa: SIM117 + self.model = get_model(vllm_config=self.vllm_config) + if self.dynamic_eplb: + model_register(self.model, self.model_config) + if is_310p(): + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, + RowParallelLinear) + for module in self.model.modules(): + if isinstance(module, + (MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear)): + module.weight.data = self._convert_torch_format( + module.weight.data) + if self.drafter: + logger.info("Loading drafter model...") + self.drafter.load_model(self.model) + if self.drafter.name == SpecDcodeType.EAGLE3: + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) + + if self.lora_config: + self.model = self.load_lora_model(self.model, self.vllm_config, + self.device) + logger.info("Loading model weights took %.4f GB", + m.consumed_memory / float(2**30)) + + # wrap the model with full graph wrapper if needed. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.update_stream: torch.npu.Stream = torch.npu.Stream() + set_graph_params(self.compilation_config.cudagraph_capture_sizes) + self.model = ACLGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + + def _convert_torch_format(self, tensor): + if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \ + and not is_enable_nz(tensor.dtype): + return tensor + tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) + return tensor + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + kv_cache_config = deepcopy(kv_cache_config) + self.kv_cache_config = kv_cache_config + self.may_add_encoder_only_layers_to_kv_cache_config() + # NOTE(cmq): initialize_attn_backend must before using self.attn_groups + self.initialize_attn_backend(kv_cache_config) + self.use_hybrid_blocks = (len(self.attn_groups) > 1) + # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`. 
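+        # NOTE (illustrative, not part of this change): after this flag is set,
+        # KV cache tensors are allocated by one of the three paths below:
+        # deepseek SFA (sparse) caches, deepseek MLA nope/rope caches, or the
+        # default per-layer K/V buffers; the resulting dict is then registered
+        # with the KV-transfer group when one exists.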
+ self.need_accepted_tokens = any([ + isinstance(attn_group[0].kv_cache_spec, MambaSpec) + for attn_group in self.attn_groups + ]) + + self.may_reinitialize_input_batch(kv_cache_config) + + if self.use_sparse: + kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa( + kv_cache_config) + elif self.model_config.is_deepseek_mla: + kv_caches = self.initialize_kv_cache_tensors_deepseek_mla( + kv_cache_config) + else: + kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + + def _align_memory(self, tensor: torch.Tensor, + alignment: int) -> torch.Tensor: + data_ptr = tensor.data_ptr() + aligned_addr = (data_ptr + alignment - 1) // alignment * alignment + offset = (aligned_addr - data_ptr) // tensor.element_size() + return tensor[int(offset):] + + def initialize_kv_cache_tensors_deepseek_sfa( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + kv_cache_sizes = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in " + "NPU.") + kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + + kv_caches: Dict[str, torch.Tensor] = {} + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + tensor_size = kv_cache_sizes[layer_name] + num_blocks = tensor_size // kv_cache_spec.page_size_bytes + if self.vllm_config.additional_config.get( + "kv_cache_dtype", None) == 'int8': + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + elif hasattr( + attn_backend, "get_supported_block_size" + ) and not self.model_config.is_deepseek_mla and not self.use_sparse: + block_size = attn_backend.get_supported_block_size()[0] + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + else: + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + alignment = 2 * 1024 * 1024 + num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + nope_dim = head_size - rope_dim + nope_cache_shape = (num_blocks, block_size, num_kv_heads, + nope_dim) + rope_cache_shape = (num_blocks, block_size, num_kv_heads, + rope_dim) + #### k cache + # TODO(zzzzwwjj): wait transformers add these params + k_cache_shape = (num_blocks, block_size, 1, 128) + if self.vllm_config.kv_transfer_config is None: + # For no disaggregate pd scenario, allocate kv cache in normal way + rope_cache = torch.zeros(rope_cache_shape, + dtype=dtype, + device=self.device) + nope_cache = torch.zeros(nope_cache_shape, + dtype=dtype, + device=self.device) + rope_cache = self._convert_torch_format(rope_cache) + nope_cache = self._convert_torch_format(nope_cache) + + #### k cache + k_cache = torch.zeros(k_cache_shape, + dtype=dtype, + device=self.device) + k_cache = self._convert_torch_format(k_cache) + else: + + # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory + # address 
should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but + # we found there are also some exceptions during test, so we manual align those memory here, this part + # of code may consume 2M * 2 * elem_size memory every layer. + nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim + nope_allocate_shape_alignment = nope_allocate_shape + alignment + rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim + rope_allocate_shape_alignment = rope_allocate_shape + alignment + + nope_cache = torch.zeros(nope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + rope_cache = torch.zeros(rope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + #### k cache + # TODO(zzzzwwjj): wait transformers add these params + k_allocate_shape = num_blocks * block_size * 1 * 128 + k_allocate_shape_alignment = k_allocate_shape + alignment + k_cache = torch.zeros(k_allocate_shape_alignment, + dtype=dtype, + device=self.device) + + nope_cache = self._align_memory( + nope_cache, + alignment)[:nope_allocate_shape].view(nope_cache_shape) + rope_cache = self._align_memory( + rope_cache, + alignment)[:rope_allocate_shape].view(rope_cache_shape) + k_cache = self._align_memory( + k_cache, + alignment)[:k_allocate_shape].view(k_cache_shape) + + kv_caches[layer_name] = (nope_cache, rope_cache, k_cache) + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + + return kv_caches + + def initialize_kv_cache_tensors_deepseek_mla( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + kv_cache_sizes = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in " + "NPU.") + kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + + kv_caches: Dict[str, torch.Tensor] = {} + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + tensor_size = kv_cache_sizes[layer_name] + num_blocks = tensor_size // kv_cache_spec.page_size_bytes + if self.vllm_config.additional_config.get( + "kv_cache_dtype", None) == 'int8': + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + elif hasattr(attn_backend, "get_supported_block_size" + ) and not self.model_config.is_deepseek_mla: + block_size = attn_backend.get_supported_block_size()[0] + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + else: + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + alignment = 2 * 1024 * 1024 + num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + nope_dim = head_size - rope_dim + nope_cache_shape = (num_blocks, block_size, num_kv_heads, + nope_dim) + rope_cache_shape = (num_blocks, block_size, num_kv_heads, + rope_dim) + if self.vllm_config.kv_transfer_config is None: + # For no disaggregate pd scenario, allocate kv cache in normal way + rope_cache = torch.zeros(rope_cache_shape, 
+ dtype=dtype, + device=self.device) + nope_cache = torch.zeros(nope_cache_shape, + dtype=dtype, + device=self.device) + rope_cache = self._convert_torch_format(rope_cache) + nope_cache = self._convert_torch_format(nope_cache) + else: + + # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory + # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but + # we found there are also some exceptions during test, so we manual align those memory here, this part + # of code may consume 2M * 2 * elem_size memory every layer. + nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim + nope_allocate_shape_alignment = nope_allocate_shape + alignment + rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim + rope_allocate_shape_alignment = rope_allocate_shape + alignment + + nope_cache = torch.zeros(nope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + rope_cache = torch.zeros(rope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + nope_cache = self._align_memory( + nope_cache, + alignment)[:nope_allocate_shape].view(nope_cache_shape) + rope_cache = self._align_memory( + rope_cache, + alignment)[:rope_allocate_shape].view(rope_cache_shape) + kv_caches[layer_name] = (nope_cache, rope_cache) + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + + return kv_caches + + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + # init kv cache tensors + kv_cache_raw_tensors: dict[str, Union[torch.Tensor, + Optional[torch.Tensor]]] = {} + # llmdatadist need the addr of cache tensor be aligned with 2M + alignment = 2 * 1024 * 1024 + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + # TODO: REFACTOR ME to sharing hybrid cache + for idx in range(len(kv_cache_tensor.shared_by)): + layer_name = kv_cache_tensor.shared_by[idx] + if "linear_attn" in layer_name: + # for mamba linear attention + for layer_name_inner in kv_cache_tensor.shared_by: + if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \ + layer_name_inner in kv_cache_raw_tensors.keys(): + continue + if self.vllm_config.kv_transfer_config is None: + tensor = torch.zeros(kv_cache_tensor.size, + dtype=torch.int8, + device=self.device) + else: + cache_size_aligned = kv_cache_tensor.size + alignment + tensor = torch.zeros(cache_size_aligned, + dtype=torch.int8, + device=self.device) + tensor = self._align_memory( + tensor, alignment)[:kv_cache_tensor.size] + kv_cache_raw_tensors[layer_name_inner] = tensor + elif "attn" in layer_name: + # for other attentions, e.g., self_attn, sliding window attn + if self.vllm_config.kv_transfer_config is None: + k_tensor = torch.zeros(kv_cache_tensor.size // 2, + dtype=torch.int8, + device=self.device) + v_tensor = torch.zeros(kv_cache_tensor.size // 2, + dtype=torch.int8, + device=self.device) + else: + cache_size = kv_cache_tensor.size // 2 + cache_size_aligned = kv_cache_tensor.size // 2 + alignment + k_tensor = torch.zeros(cache_size_aligned, + dtype=torch.int8, + device=self.device) + v_tensor = torch.zeros(cache_size_aligned, + dtype=torch.int8, + device=self.device) + k_tensor = self._align_memory(k_tensor, + alignment)[:cache_size] + v_tensor 
= self._align_memory(v_tensor, + alignment)[:cache_size] + kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor) + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) + assert layer_names == set(kv_cache_raw_tensors.keys( + )), "Some layers are not correctly initialized" + + kv_caches: Dict[str, torch.Tensor] = {} + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + + # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may + # encounter OOM issue + if isinstance(kv_cache_spec, FullAttentionSpec): + raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore + layer_name] + assert raw_k_tensor is not None + assert raw_v_tensor is not None + assert (raw_k_tensor.numel() + raw_v_tensor.numel() + ) % kv_cache_spec.page_size_bytes == 0 + num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel() + ) // kv_cache_spec.page_size_bytes + + # `num_blocks` is the number of blocks the model runner can use. + # `kv_cache_config.num_blocks` is the number of blocks that + # KVCacheManager may allocate. + # Since different GPUs may have different number of layers and + # different memory capacities, `num_blocks` can be different on + # different GPUs, and `kv_cache_config.num_blocks` is set to + # the min of all `num_blocks`. Verify it here. + assert num_blocks >= kv_cache_config.num_blocks + + if self.vllm_config.additional_config.get( + "kv_cache_dtype", None) == 'int8': + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + elif hasattr(attn_backend, "get_supported_block_size" + ) and self.use_hybrid_blocks: + block_size = attn_backend.get_supported_block_size()[0] + + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + else: + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:]) + k_cache = self._convert_torch_format(k_cache) + v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:]) + v_cache = self._convert_torch_format(v_cache) + kv_caches[layer_name] = (k_cache, v_cache) + elif isinstance(kv_cache_spec, MambaSpec): + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor is not None + assert raw_tensor.numel( + ) % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel( + ) // kv_cache_spec.page_size_bytes + + # `num_blocks` is the number of blocks the model runner can use. + # `kv_cache_config.num_blocks` is the number of blocks that + # KVCacheManager may allocate. + # Since different GPUs may have different number of layers and + # different memory capacities, `num_blocks` can be different on + # different GPUs, and `kv_cache_config.num_blocks` is set to + # the min of all `num_blocks`. Verify it here. 
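The 2 MiB alignment applied to the raw KV cache tensors above (over-allocate by `alignment` bytes, then slice at an aligned offset before viewing back to the cache shape) can be shown in isolation. `_align_memory` itself is not part of this hunk, so the helper below is a hypothetical stand-in sketched on a CPU tensor; it is only meant to illustrate the offset arithmetic, not the real NPU path.

import torch

ALIGNMENT = 2 * 1024 * 1024  # llmdatadist requires 2 MiB-aligned buffers


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    # Hypothetical stand-in for self._align_memory: drop leading elements so
    # that the returned view starts at an address divisible by `alignment`.
    addr = tensor.data_ptr()
    aligned_addr = (addr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - addr) // tensor.element_size()
    return tensor[offset:]


requested = 1024 * 1024                                      # bytes actually needed
raw = torch.zeros(requested + ALIGNMENT, dtype=torch.int8)   # pad by `alignment`
cache = align_memory(raw, ALIGNMENT)[:requested]
assert cache.data_ptr() % ALIGNMENT == 0
assert cache.numel() == requested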
+ assert num_blocks >= kv_cache_config.num_blocks + + state_tensors = [] + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset_bytes // dtype_size, + ) + state_tensors.append(tensor) + storage_offset_bytes += stride[0] * dtype_size + kv_caches[layer_name] = state_tensors + else: + raise ValueError("Unknown KV cache spec type.") + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + + return kv_caches + + def may_reinitialize_input_batch(self, + kv_cache_config: KVCacheConfig) -> None: + """ + Re-initialize the input batch if the block sizes are different from + `[self.cache_config.block_size]`. This usually happens when there + are multiple KV cache groups. + + Args: + kv_cache_config: The KV cache configuration. + """ + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec) + ] + + # Generate kernel_block_sizes that matches each block_size + # For attention backends that support virtual block splitting, + # use the supported block sizes from the backend + # For other backends (like Mamba), use [0] (no splitting) + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups): + + if isinstance(kv_cache_group.kv_cache_spec, + EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # the backend. + try: + attn_groups = self.attn_groups[kv_cache_group_id] + except IndexError: + attn_groups = None + if attn_groups and self.use_hybrid_blocks: + # Use the backend's supported block size list + backend = attn_groups[0].backend + supported_sizes = backend.get_supported_block_size() + # If no specific sizes supported, use cache config + # block_size + kernel_block_size_list = (supported_sizes + if supported_sizes else + [self.cache_config.block_size]) + else: + # Fallback to cache config block_size if no backend found + kernel_block_size_list = [self.cache_config.block_size] + kernel_block_sizes.append(kernel_block_size_list) + else: + # This is likely Mamba or other non-attention cache, + # no splitting. + # NOTE: set kernel_block_sizes to 0 to disable slotmapping computation + # of mamba block. In this case, BlockTable.block_size will never equal + # to kernel_block_sizes[0] + kernel_block_sizes.append([0]) + + if block_sizes != [ + self.cache_config.block_size + ] or kernel_block_sizes != [[self.cache_config.block_size]]: + assert self.cache_config.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. 
See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details.") + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.model_config.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=block_sizes, + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), + kernel_block_sizes=kernel_block_sizes, + ) + + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. + """ + block_size = self.vllm_config.cache_config.block_size + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize the attention backends and attention metadata builders. + """ + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" + + class AttentionGroupKey(NamedTuple): + attn_backend: type[AttentionBackend] + kv_cache_spec: KVCacheSpec + + def get_attn_backends_for_group( + kv_cache_group_spec: KVCacheGroupSpec, + ) -> dict[AttentionGroupKey, list[str]]: + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, + kv_cache_group_spec.layer_names) + attn_backends = {} + attn_backend_layers = defaultdict(list) + # Dedupe based on full class name; this is a bit safer than + # using the class itself as the key because when we create dynamic + # attention backend subclasses (e.g. ChunkedLocalAttention) unless + # they are cached correctly, there will be different objects per + # layer. 
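The comment above motivates keying the grouping on the backend's fully qualified class name rather than on the class object itself. A toy sketch (hypothetical class names, plain strings in place of the real spec objects) shows why: dynamically created subclasses are distinct objects per layer but share one name, so name-based keys collapse them into a single group.

from collections import defaultdict


def make_backend():
    # Stands in for a dynamically created attention backend subclass.
    return type("ChunkedLocalAttention", (object,), {})


def full_name(cls) -> str:
    return f"{cls.__module__}.{cls.__qualname__}"


b0, b1 = make_backend(), make_backend()
assert b0 is not b1                        # different class objects per layer
assert full_name(b0) == full_name(b1)      # but the same fully qualified name

groups: dict[str, list[str]] = defaultdict(list)
for layer_name, backend in {"layers.0.attn": b0, "layers.1.attn": b1}.items():
    groups[full_name(backend)].append(layer_name)
assert list(groups.values()) == [["layers.0.attn", "layers.1.attn"]]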
+ for layer_name in kv_cache_group_spec.layer_names: + attn_backend = layers[layer_name].get_attn_backend() + full_cls_name = attn_backend.full_cls_name() + layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ + layer_name] + key = (full_cls_name, layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey(attn_backend, + layer_kv_cache_spec) + attn_backend_layers[key].append(layer_name) + return { + attn_backends[k]: v + for k, v in attn_backend_layers.items() + } + + def create_attn_groups( + attn_backends_map: dict[AttentionBackend, list[str]], + ) -> list[AttentionGroup]: + attn_groups: list[AttentionGroup] = [] + for (attn_backend, + kv_cache_spec), layer_names in attn_backends_map.items(): + attn_metadata_builders = [] + attn_metadata_builders.append(attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + )) + attn_group = AttentionGroup(attn_backend, + attn_metadata_builders, + layer_names, kv_cache_spec) + attn_groups.append(attn_group) + return attn_groups + + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + attn_backends = get_attn_backends_for_group( # type: ignore + kv_cache_group_spec) + self.attn_groups.append(create_attn_groups(attn_backends)) + + # Calculate reorder batch threshold (if needed) + self.calculate_reorder_batch_threshold() + + def _attn_group_iterator(self) -> Iterator[AttentionGroup]: + return itertools.chain.from_iterable(self.attn_groups) + + def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: + if not self.kv_cache_config.kv_cache_groups: + return + for attn_groups in self.attn_groups: + yield from attn_groups + + def calculate_reorder_batch_threshold(self) -> None: + """ + Check that if any backends reorder batches; that the reordering + is compatible (e.g., decode threshold is the same) + """ + for group in self._attn_group_iterator(): + attn_metadata_builder_i = group.get_metadata_builder() + if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"): + # check that if any backends reorder batches; that the reordering + # is compatible (e.g., decode threshold is the same) + reorder_batch_threshold_i = ( + attn_metadata_builder_i.reorder_batch_threshold) + if reorder_batch_threshold_i is not None: + if self.reorder_batch_threshold is not None: + if reorder_batch_threshold_i != \ + self.reorder_batch_threshold: + raise ValueError( + f"Attention backend reorders decodes with " + f"threshold {reorder_batch_threshold_i} but other " + f"backend uses threshold " + f"{self.reorder_batch_threshold}") + else: + self.reorder_batch_threshold = reorder_batch_threshold_i + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + use_sparse = self.use_sparse + kv_cache_spec: dict[str, KVCacheSpec] = {} + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. 
We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue + if isinstance(attn_module, AscendMultiHeadLatentAttention): + continue + + # TODO: Support other attention modules, e.g., cross-attention + # TODO(lucas): move the attention specs into the model layers like + # the attention backends + if attn_module.attn_type == AttentionType.DECODER: + if use_mla and not use_sparse: + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + cache_dtype_str=self.cache_config.cache_dtype) + else: + # TODO(cmq): This is a hack way to fix deepseek kvcache when + # using DSA. Fix the spec in vLLM is a finnal way. + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) + if len(mamba_layers) > 0: + if (self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"]): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet.") + if self.vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching is not supported for Mamba yet.") + max_model_len = self.vllm_config.model_config.max_model_len + + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) + + # Set block_size to max_model_len, so that mamba model will always + # have only one block in the KV cache. 
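As a quick check of the "only one block" claim above: with the mamba block size pinned to max_model_len, the per-request block count is always one regardless of sequence length (the numbers below are made up).

import math

max_model_len = 32768
block_size = max_model_len            # MambaSpec block_size, as set just below
for num_tokens in (1, 4096, max_model_len):
    assert math.ceil(num_tokens / block_size) == 1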
+ for layer_name, mamba_module in mamba_layers.items(): + kv_cache_spec[layer_name] = MambaSpec( + shapes=mamba_module.get_state_shape(), + dtypes=mamba_module.get_state_dtype(), + block_size=max_model_len, + page_size_padded=page_size_padded, + mamba_type=mamba_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0), + ) + + return kv_cache_spec + + def initialize_aclgraph_capture(self) -> None: + min_ag_support = AttentionCGSupport.ALWAYS + min_ag_builder_name = None + + for attn_group in self._attn_group_iterator(): + builder = attn_group.get_metadata_builder() + if builder.aclgraph_support.value < min_ag_support.value: + min_ag_support = builder.aclgraph_support + min_ag_builder_name = builder.__class__.__name__ + + # This is an imitation of compilation_config.splitting_ops_contain_attention() + splitting_ops_contain_attention = ( + self.compilation_config.splitting_ops is not None + and all(op in self.compilation_config.splitting_ops for op in [ + "vllm.unified_ascend_attention_with_output", + "vllm.mla_forward", + ])) + + # Flexible resolve the aclgraph mode + aclgraph_mode = self.compilation_config.cudagraph_mode + # check graph for mixed batch is supported + if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \ + and min_ag_support != AttentionCGSupport.ALWAYS: + msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported " + f"with {min_ag_builder_name} backend (support: " + f"{min_ag_support})") + if min_ag_support == AttentionCGSupport.NEVER: + # if not supported any full graphs, just raise it. + msg += "; please try cudagraph_mode=PIECEWISE, and "\ + "make sure compilation level is piecewise" + raise ValueError(msg) + + # attempt to resolve the full graph related mode + if splitting_ops_contain_attention: + msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" + aclgraph_mode = self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.FULL_AND_PIECEWISE) + else: + msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" + aclgraph_mode = self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.FULL_DECODE_ONLY) + logger.warning(msg) + + # double check that we can support full graph if they are requested + # even after automatic downgrades + if aclgraph_mode.has_full_cudagraphs() \ + and min_ag_support == AttentionCGSupport.NEVER: + raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not " + f"supported with {min_ag_builder_name} backend (" + f"support:{min_ag_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise") + + self.aclgraph_dispatcher.initialize_cudagraph_keys( + self.compilation_config.cudagraph_mode, + self.uniform_decode_query_len) + + def _capture_aclgraphs(self, compilation_cases: list[int], + aclgraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool): + assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \ + aclgraph_runtime_mode in [CUDAGraphMode.FULL, + CUDAGraphMode.PIECEWISE] + + # Only rank 0 should print progress bar during capture + if is_global_first_rank(): + logger.info( + "Starting to capture ACL graphs for cases: %s, " + "mode: %s, uniform_decode: %s", compilation_cases, + aclgraph_runtime_mode.name, uniform_decode) + compilation_cases = tqdm( + compilation_cases, + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing ACL graphs ({}, {})".format( + "decode" if uniform_decode else "mixed prefill-decode", + aclgraph_runtime_mode.name)) + # We skip EPLB here since we don't want to record dummy metrics + for 
num_tokens in compilation_cases: + for _ in range(self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. + force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL) + self._dummy_run(num_tokens, + aclgraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode) + self._dummy_run(num_tokens, + aclgraph_runtime_mode=aclgraph_runtime_mode, + force_attention=force_attention, + uniform_decode=uniform_decode) + + def _capture_model(self): + if not self.use_aclgraph: + logger.warning( + "Skipping ACL graph capture. To turn on ACL graph capture, " + "ensure `aclraph_mode` was not manually set to `NONE`") + return + else: + self.initialize_aclgraph_capture() + + set_cudagraph_capturing_enabled(True) + # Trigger ACL graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with graph_capture(device=self.device): + aclgraph_mode = self.compilation_config.cudagraph_mode + if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE: + aclgraph_runtime_mode = aclgraph_mode.mixed_mode() + + compilation_cases = list(reversed(self.aclgraph_batch_sizes)) + + try: + self._capture_aclgraphs( + compilation_cases, + aclgraph_runtime_mode=aclgraph_runtime_mode, + uniform_decode=False) + except Exception as e: + error_msg = str(e) + error_code = '0x7020023' + pattern = r'retCode=([^,\s\.]+)' + match = re.search(pattern, error_msg) + if match: + retCode = match.group(1) + # Determine whether the error message is caused by stream capture failure. + if match and retCode == error_code: + logger.error( + f"ACLgraph sizes capture fail: {type(e).__name__}:\n" + "ACLgraph has insufficient available streams to capture the configured number of sizes. " + "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n" + "Recommended solutions:\n" + "1. Manually configure the compilation_config parameter " + "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n" + "2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n" + f"{str(e)}") + raise + + if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + aclgraph_mode.separate_routine(): + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len + decode_cudagraph_batch_sizes = [ + x for x in self.aclgraph_batch_sizes if x <= max_num_tokens + and x >= self.uniform_decode_query_len + ] + compilation_cases_decode = list( + reversed(decode_cudagraph_batch_sizes)) + if not all(x % self.uniform_decode_query_len == 0 + for x in decode_cudagraph_batch_sizes): + raise ValueError( + "In the MTP fullgraph scenario, each graph size must be an integer multiple of " + f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. " + f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, " + f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. " + "For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]." 
+ ) + self._capture_aclgraphs( + compilation_cases=compilation_cases_decode, + aclgraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True) + + # Disable aclgraph capturing globally, so any unexpected aclgraph + # capturing will be detected and raise an error after here. + # Note: We don't put it into graph_capture context manager because + # we may doing lazy capturing in future that still allows capturing + # after here. + set_cudagraph_capturing_enabled(False) + + def capture_model(self) -> None: + + compilation_counter.num_gpu_runner_capture_triggers += 1 + + start_time = time.perf_counter() + start_free_npu_memory = torch.npu.mem_get_info()[0] + + self._capture_model() + + end_time = time.perf_counter() + end_free_npu_memory = torch.npu.mem_get_info()[0] + elapsed_time = end_time - start_time + npu_graph_size = start_free_npu_memory - end_free_npu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, npu_graph_size / (1 << 30)) + + def _get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + scheduler_output: "SchedulerOutput", + ) -> dict[str, Optional[LogprobsTensors]]: + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.requests[req_id] + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. + start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. 
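The chunked bookkeeping above can be summarized with a small standalone calculation: a prompt of N tokens yields N - 1 prompt-logprob rows, and each prefill chunk fills at most as many rows as it scheduled, truncated at the end of the prompt. The request and chunk sizes below are made up.

num_prompt_tokens = 8
rows_total = num_prompt_tokens - 1        # one row per "next token" position
filled = 0
# (num_computed_tokens, num_scheduled_tokens) for each prefill step
for num_computed, num_scheduled in [(0, 3), (3, 3), (6, 2)]:
    start_tok = num_computed + 1
    num_remaining = num_prompt_tokens - start_tok
    num_logits = min(num_scheduled, num_remaining)   # last chunk is truncated
    filled += max(num_logits, 0)
assert filled == rows_total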
+ req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc_np[req_idx].item() + prompt_hidden_states = hidden_states[offset:offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states) + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids) + + # Transfer NPU->CPU async. + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_( + token_ids, non_blocking=True) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, + non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_( + ranks, non_blocking=True) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking NPU->CPU transfers. + if prompt_logprobs_dict: + torch.npu.synchronize() + + return prompt_logprobs_dict + + def get_supported_pooling_tasks(self): + model = self.get_model() + if not is_pooling_model(model): + return [] + + return list(model.pooler.get_supported_tasks()) + + def _build_drafter_prepare_inputs_torchair_param(self): + return False diff --git a/vllm_npu/worker/npu_input_batch.py b/vllm_npu/worker/npu_input_batch.py new file mode 100644 index 0000000..d859032 --- /dev/null +++ b/vllm_npu/worker/npu_input_batch.py @@ -0,0 +1,842 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py +# + +from dataclasses import dataclass +from typing import Optional, cast + +import numpy as np +import torch +from typing_extensions import deprecated +from vllm.lora.request import LoRARequest +from vllm.multimodal.inputs import (MultiModalFeatureSpec, + MultiModalKwargsItem, + MultiModalKwargsItems, PlaceholderRange) +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import swap_dict_values +from vllm.v1.outputs import LogprobsTensors +from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, + LogitsProcessors, + MoveDirectionality) +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.utils import is_spec_decode_unsupported +from vllm.v1.utils import copy_slice + +from vllm_npu.worker.block_table import MultiGroupBlockTable + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: list[int] + sampling_params: Optional[SamplingParams] + pooling_params: Optional[PoolingParams] + generator: Optional[torch.Generator] + + block_ids: tuple[list[int], ...] + num_computed_tokens: int + output_token_ids: list[int] + + mrope_positions: Optional[torch.Tensor] = None + mrope_position_delta: Optional[int] = None + + mm_features: Optional[list[MultiModalFeatureSpec]] = None + # for back-compatibility, will be removed in next major release + mm_kwargs: Optional[list[MultiModalKwargsItem]] = None + mm_positions: Optional[list[PlaceholderRange]] = None + mm_hashes: Optional[list[PlaceholderRange]] = None + + lora_request: Optional[LoRARequest] = None + + def __post_init__(self): + self.num_prompt_tokens = len(self.prompt_token_ids) + + @property + def num_tokens(self) -> int: + return self.num_prompt_tokens + len(self.output_token_ids) + + # Temporary back-compatibility for plugins that define model runner + @property + @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " + "removed in v0.13. Please use `mm_kwargs` instead.") + def mm_inputs(self) -> list[MultiModalKwargsItems]: + assert self.mm_features is not None + return [ + MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features + if f.data is not None + ] + + def get_token_id(self, idx: int) -> int: + if idx < self.num_prompt_tokens: + return self.prompt_token_ids[idx] + else: + return self.output_token_ids[idx - self.num_prompt_tokens] + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group + logitsprocs: Optional[LogitsProcessors] = None, + is_spec_decode: bool = False, + is_pooling_model: bool = False, + num_speculative_tokens: int = 0, + kernel_block_sizes: Optional[list[list[int]]] = None): + self.is_pooling_model = is_pooling_model + self.is_spec_decode = is_spec_decode + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_batched_tokens = max_num_batched_tokens + self.device = device + self.pin_memory = pin_memory + self.vocab_size = vocab_size + + self._req_ids: list[Optional[str]] = [] + self.req_id_to_index: dict[str, int] = {} + + # TODO(woosuk): This buffer could be too large if max_model_len is big. + # Find a way to reduce the CPU memory usage. + # This buffer is not directly transferred to the NPU, so it does not + # need to be pinned. 
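A minimal illustration of the CPU-buffer pattern used throughout this class, with arbitrary shapes: a torch CPU tensor provides the copy_()/slicing API, while a NumPy view of the same storage gives cheap scalar indexing, and writes through either are visible in the other.

import numpy as np
import torch

max_num_reqs, max_model_len = 4, 16
token_ids_cpu_tensor = torch.zeros((max_num_reqs, max_model_len),
                                   dtype=torch.int32, pin_memory=False)
token_ids_cpu = token_ids_cpu_tensor.numpy()     # zero-copy view, same storage

token_ids_cpu[1, :3] = np.array([7, 8, 9], dtype=np.int32)
assert token_ids_cpu_tensor[1, :3].tolist() == [7, 8, 9]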
+ self.token_ids_cpu_tensor = torch.zeros( + (max_num_reqs, max_model_len), + device="cpu", + dtype=torch.int32, + pin_memory=False, + ) + self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu_tensor = torch.zeros( + (max_num_reqs, ), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_computed_tokens_cpu = \ + self.num_computed_tokens_cpu_tensor.numpy() + + # Block table. + self.block_table = MultiGroupBlockTable( + max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + pin_memory=pin_memory, + device=device, + block_sizes=block_sizes, + num_speculative_tokens=num_speculative_tokens, + kernel_sizes=kernel_block_sizes) + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: set[str] = set() + self.random_reqs: set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: set[str] = set() + + # IDs of requests which do not support spec decoding + self.spec_decode_unsupported_reqs: set[str] = set() + + # Frequency penalty related data structures + self.frequency_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_reqs: set[str] = set() + + # Presence penalty related data structures + self.presence_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + ) + self.presence_penalties_reqs: set[str] = set() + + # Repetition penalty related data structures + self.repetition_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu = \ + self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_reqs: set[str] = set() + + # Speculative decoding + self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), + dtype=torch.int64, + device="cpu", + pin_memory=pin_memory) + self.num_accepted_tokens_cpu = \ + self.num_accepted_tokens_cpu_tensor.numpy() + + # lora related + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + 
dtype=np.int32) + self.lora_id_to_request_ids: dict[int, set[str]] = {} + self.lora_id_to_lora_request: dict[int, LoRARequest] = {} + + # req_index -> generator + # NOTE(woosuk): The indices of the requests that do not have their own + # generator should not be included in the dictionary. + self.generators: dict[int, torch.Generator] = {} + + self.num_logprobs: dict[str, int] = {} + # NOTE(rob): num_prompt_logprobs only includes reqs + # that are currently in the prefill phase. + self.num_prompt_logprobs: dict[str, int] = {} + + # To accumulate prompt logprobs tensor chunks across prefill steps. + self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} + + # Internal representation of per-step batch state changes, used for + # reordering persistent batch and generating logitsprocs batch state + # updates. Should reset each step. + self.batch_update_builder = BatchUpdateBuilder() + + # TODO convert this to LogitsProcessor + self.has_allowed_token_ids: set[str] = set() + # NOTE(lufang): In the mask tensor, if the corresponding token allowed, + # the value is False. Since we use masked_fill_ to set -inf. + self.allowed_token_ids_mask: Optional[torch.Tensor] = None + self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None + + # req_index -> bad_words_token_ids + self.bad_words_token_ids: dict[int, list[list[int]]] = {} + + self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, + dtype=bool) + + self.req_output_token_ids: list[Optional[list[int]]] = [] + + # Store provided logitsprocs. If none are provided, initialize empty + # data structure + self.logitsprocs = logitsprocs or LogitsProcessors() + + # This is updated each time the batch constituents change. + self.sampling_metadata = self._make_sampling_metadata() + + self.pooling_params: dict[str, PoolingParams] = {} + + # Cached reference to the GPU tensor of previously sampled tokens + self.prev_sampled_token_ids: Optional[torch.Tensor] = None + self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None + self.prev_req_id_to_index: Optional[dict[str, int]] = None + + @property + def req_ids(self) -> list[str]: + # None elements should only be present transiently + # while performing state updates to the batch. + return cast(list[str], self._req_ids) + + def _register_add_request(self, request: "CachedRequestState") -> int: + """Track add-request operations for logits processors. + Not applicable to pooling models. + """ + + # Detailed added request metadata is only required for non-pooling + # models, to support logitsprocs + assert request.sampling_params + + # Fill the next empty index if there is one. + if (new_req_index := self.batch_update_builder.pop_removed()) is None: + # Append to end otherwise. + new_req_index = self.num_reqs + + assert new_req_index < self.max_num_reqs + self.batch_update_builder.added.append( + (new_req_index, request.sampling_params, request.prompt_token_ids, + request.output_token_ids)) + return new_req_index + + def add_request( + self, + request: "CachedRequestState", + ) -> int: + if not self.is_pooling_model: + # New request index bookkeeping for autoregressive models. 
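The slot-reuse policy implemented in _register_add_request above (and invoked just below) can be sketched with a plain list standing in for the builder's pool of removed indices: a new request fills the lowest freed slot if one exists, otherwise it is appended at the end of the batch.

removed_slots = [1]              # slot 1 was freed by an earlier remove_request()
num_reqs = 3                     # slots 0 and 2 are still occupied


def next_slot() -> int:
    # Stand-in for BatchUpdateBuilder.pop_removed() plus the
    # "append to end otherwise" fallback shown above.
    return removed_slots.pop(0) if removed_slots else num_reqs


assert next_slot() == 1          # reuse the hole first
assert next_slot() == 3          # then grow the batch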
+ req_index = self._register_add_request(request) + else: + req_index = self.num_reqs + + req_id = request.req_id + if req_index == len(self._req_ids): + self._req_ids.append(req_id) + self.req_output_token_ids.append(request.output_token_ids) + else: + self._req_ids[req_index] = req_id + self.req_output_token_ids[req_index] = request.output_token_ids + + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.num_prompt_tokens[req_index] = num_prompt_tokens + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + # Number of token ids in token_ids_cpu. + # NOTE(woosuk): This may include spec decode tokens. + self.num_tokens[req_index] = request.num_tokens + # Number of tokens without spec decode tokens. + self.num_tokens_no_spec[req_index] = request.num_tokens + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + self.block_table.add_row(request.block_ids, req_index) + + if sampling_params := request.sampling_params: + if (self.is_spec_decode + and is_spec_decode_unsupported(sampling_params)): + self.spec_decode_unsupported_reqs.add(req_id) + if sampling_params.sampling_type == SamplingType.GREEDY: + # Avoid later division by zero. + self.temperature_cpu[req_index] = -1.0 + self.greedy_reqs.add(req_id) + else: + self.temperature_cpu[req_index] = sampling_params.temperature + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + top_k = sampling_params.top_k + if 0 < top_k < self.vocab_size: + self.top_k_reqs.add(req_id) + else: + top_k = self.vocab_size + self.top_k_cpu[req_index] = top_k + self.frequency_penalties_cpu[ + req_index] = sampling_params.frequency_penalty + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[ + req_index] = sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[ + req_index] = sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) + + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. + if request.generator is not None: + self.generators[req_index] = request.generator + + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = (self.vocab_size + if sampling_params.logprobs == -1 + else sampling_params.logprobs) + if sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[ + req_id] = sampling_params.prompt_logprobs + + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu_tensor is None: + # Lazy allocation for this tensor, which can be large. + # False means we don't fill with -inf. + self.allowed_token_ids_mask = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device=self.device) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device="cpu") + self.allowed_token_ids_mask_cpu_tensor[req_index] = True + # False means we don't fill with -inf. 
+ self.allowed_token_ids_mask_cpu_tensor[req_index][ + sampling_params.allowed_token_ids] = False + + if sampling_params.bad_words_token_ids: + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids + elif pooling_params := request.pooling_params: + self.pooling_params[req_id] = pooling_params + self.logits_processing_needs_token_ids[req_index] = ( + pooling_params.requires_token_ids) + else: + raise NotImplementedError(request) + + # Speculative decoding: by default 1 token is generated. + self.num_accepted_tokens_cpu[req_index] = 1 + + # Add request lora ID + if request.lora_request: + lora_id = request.lora_request.lora_int_id + if lora_id not in self.lora_id_to_request_ids: + self.lora_id_to_request_ids[lora_id] = set() + + self.request_lora_mapping[req_index] = lora_id + self.lora_id_to_request_ids[lora_id].add(request.req_id) + self.lora_id_to_lora_request[lora_id] = request.lora_request + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + + return req_index + + def remove_request(self, req_id: str) -> Optional[int]: + """This method must always be followed by a call to condense(). + + Args: + req_id: request to remove + + Returns: + Removed request index, or `None` if `req_id` not recognized + """ + + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + if not self.is_pooling_model: + # Autoregressive models require bookkeeping of removed requests to + # support logitsprocs. + self.batch_update_builder.removed_append(req_index) + self._req_ids[req_index] = None + self.req_output_token_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.spec_decode_unsupported_reqs.discard(req_id) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) + self.in_progress_prompt_logprobs_cpu.pop(req_id, None) + + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + self.lora_id_to_request_ids[lora_id].discard(req_id) + if len(self.lora_id_to_request_ids[lora_id]) == 0: + self.lora_id_to_request_ids.pop(lora_id) + self.lora_id_to_lora_request.pop(lora_id) + self.request_lora_mapping[req_index] = 0 + + self.has_allowed_token_ids.discard(req_id) + if self.allowed_token_ids_mask_cpu_tensor is not None: + # False means we don't fill with -inf. 
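The mask convention repeated in these comments is easiest to see end to end: True marks token ids that will later be filled with -inf, so an unrestricted row is all False, a restricted row starts all True, and only the allowed ids are flipped back to False. The vocab size below is made up.

import torch

vocab_size = 10
row = torch.zeros(vocab_size, dtype=torch.bool)   # unrestricted: nothing masked
row[:] = True                                     # request restricts its tokens
row[[2, 5]] = False                               # allowed ids stay unmasked

logits = torch.zeros(vocab_size)
logits.masked_fill_(row, float("-inf"))
assert logits[2].item() == 0.0 and logits[5].item() == 0.0
assert torch.isinf(logits[0]).item()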
+ self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) + self.bad_words_token_ids.pop(req_index, None) + self.pooling_params.pop(req_id, None) + return req_index + + def swap_states(self, i1: int, i2: int) -> None: + # For autoregressive models, track detailed request reordering info + # to support logitsprocs + self.batch_update_builder.moved.append( + (i1, i2, MoveDirectionality.SWAP)) + old_id_i1 = self._req_ids[i1] + old_id_i2 = self._req_ids[i2] + self._req_ids[i1], self._req_ids[i2] =\ + self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ + self.req_output_token_ids[i2], self.req_output_token_ids[i1] + assert old_id_i1 is not None and old_id_i2 is not None + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ + self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] + self.num_tokens[i1], self.num_tokens[i2] =\ + self.num_tokens[i2], self.num_tokens[i1] + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ + self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ + self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ + self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] + self.temperature_cpu[i1], self.temperature_cpu[i2] =\ + self.temperature_cpu[i2], self.temperature_cpu[i1] + self.top_p_cpu[i1], self.top_p_cpu[i2] =\ + self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] =\ + self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ + self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ + self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ + self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ + self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] + + # NOTE: the following is unsafe + # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ + # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] + # instead, we need to temporiarily copy the data for one of the indices + # TODO(lucas): optimize this by only copying valid indices + tmp = self.token_ids_cpu[i1, ...].copy() + self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] + self.token_ids_cpu[i2, ...] = tmp + + swap_dict_values(self.generators, i1, i2) + swap_dict_values(self.bad_words_token_ids, i1, i2) + + self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ + self.request_lora_mapping[i2], self.request_lora_mapping[i1] + + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[i1], \ + self.allowed_token_ids_mask_cpu_tensor[i2] =\ + self.allowed_token_ids_mask_cpu_tensor[i2], \ + self.allowed_token_ids_mask_cpu_tensor[i1] + self.block_table.swap_row(i1, i2) + + def condense(self) -> None: + """Slide non-empty requests down into lower, empty indices. + + Any consecutive empty indices at the very end of the list are not + filled. + + Args: + empty_req_indices: empty indices which may be filled. 
+ + Returns: + swaps: list of (from,to) swap tuples for moved requests + empty_req_indices: indices not filled by condensation + """ + num_reqs = self.num_reqs + + if self.is_pooling_model: + # Will be contiguous in pooling case, just trim the lists. + del self._req_ids[num_reqs:] + del self.req_output_token_ids[num_reqs:] + return + + if not (empty_req_indices := self.batch_update_builder.removed): + # All removed requests were replaced by added requests, or else no + # requests were removed at all. No condense() needed + return + if num_reqs == 0: + # The batched states are empty. + self._req_ids.clear() + self.req_output_token_ids.clear() + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = self.batch_update_builder.peek_removed() + assert empty_index is not None + if empty_index >= last_req_index: + break + + # Move active request down into empty request + # index. + self.batch_update_builder.pop_removed() + # Autoregressive models require detailed tracking of condense + # operations to support logitsprocs + self.batch_update_builder.moved.append( + (last_req_index, empty_index, + MoveDirectionality.UNIDIRECTIONAL)) + req_id = self._req_ids[last_req_index] + output_token_ids = self.req_output_token_ids[last_req_index] + assert req_id is not None + self._req_ids[empty_index] = req_id + self._req_ids[last_req_index] = None + self.req_output_token_ids[empty_index] = output_token_ids + self.req_output_token_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + num_tokens = self.num_tokens[last_req_index] + self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ + last_req_index, :num_tokens] + self.num_tokens[empty_index] = num_tokens + self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ + last_req_index] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table.move_row(last_req_index, empty_index) + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.frequency_penalties_cpu[ + empty_index] = self.frequency_penalties_cpu[last_req_index] + self.presence_penalties_cpu[ + empty_index] = self.presence_penalties_cpu[last_req_index] + self.repetition_penalties_cpu[ + empty_index] = self.repetition_penalties_cpu[last_req_index] + self.num_accepted_tokens_cpu[ + empty_index] = self.num_accepted_tokens_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + + # TODO convert these to LogitsProcessors + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[ + empty_index] = self.allowed_token_ids_mask_cpu_tensor[ + last_req_index] + + bad_words_token_ids = self.bad_words_token_ids.pop( + last_req_index, None) + if bad_words_token_ids is not None: + self.bad_words_token_ids[empty_index] = bad_words_token_ids + + # Decrement last_req_index since it is 
now empty. + last_req_index -= 1 + + # Trim lists to the batch size. + del self._req_ids[num_reqs:] + del self.req_output_token_ids[num_reqs:] + + def refresh_metadata(self): + """Apply any batch updates to sampling metadata.""" + + if self.is_pooling_model: + # Batch changes every step for pooling models. + self.sampling_metadata = self._make_sampling_metadata() + return + + # For non-pooling models - generate and apply logitsprocs update; + # reset batch update tracking. + # Update sampling metadata if batch state is changed. + batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) + for logit_proc in self.logitsprocs.all: + logit_proc.update_state(batch_update) + if batch_update: + self.sampling_metadata = self._make_sampling_metadata() + + def _make_sampling_metadata(self) -> SamplingMetadata: + num_reqs = self.num_reqs + if not self.all_greedy: + temperature = copy_slice(self.temperature_cpu_tensor, + self.temperature, num_reqs) + else: + temperature = None + if not self.no_top_p: + copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) + if not self.no_top_k: + copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) + + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + copy_slice(self.frequency_penalties_cpu_tensor, + self.frequency_penalties, num_reqs) + copy_slice(self.presence_penalties_cpu_tensor, + self.presence_penalties, num_reqs) + copy_slice(self.repetition_penalties_cpu_tensor, + self.repetition_penalties, num_reqs) + + needs_prompt_token_ids = ( + not self.no_penalties + or self.logits_processing_needs_token_ids[:num_reqs].any()) + if needs_prompt_token_ids: + # The prompt tokens are used only for applying penalties or + # step pooling during the sampling/pooling process. + # Hence copy these tensors only when there are requests which + # need penalties/step_pooler to be applied. 
+ prompt_token_ids = self._make_prompt_token_ids_tensor() + else: + prompt_token_ids = None + + allowed_token_ids_mask: Optional[torch.Tensor] = None + if not self.no_allowed_token_ids: + assert self.allowed_token_ids_mask is not None + copy_slice(self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, num_reqs) + allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] + + return SamplingMetadata( + temperature=temperature, + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=None if self.no_top_p else self.top_p[:num_reqs], + top_k=None if self.no_top_k else self.top_k[:num_reqs], + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=prompt_token_ids, + frequency_penalties=self.frequency_penalties[:num_reqs], + presence_penalties=self.presence_penalties[:num_reqs], + repetition_penalties=self.repetition_penalties[:num_reqs], + output_token_ids=cast(list[list[int]], self.req_output_token_ids), + no_penalties=self.no_penalties, + allowed_token_ids_mask=allowed_token_ids_mask, + bad_words_token_ids=self.bad_words_token_ids, + logitsprocs=self.logitsprocs, + ) + + @property + def pooling_metadata(self) -> PoolingMetadata: + if len(self.pooling_params) == 0: + pooling_params = [] + else: + # Note, for now this assumes that all request in the batch + # are either sampling or pooling requests + assert len(self.req_ids) == len(self.pooling_params) + pooling_params = [ + self.pooling_params[req_id] for req_id in self.req_ids + ] + + return PoolingMetadata( + prompt_lens=torch.from_numpy( + self.num_prompt_tokens[:self.num_reqs]), + prompt_token_ids=self.sampling_metadata.prompt_token_ids, + pooling_params=pooling_params, + ) + + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: + max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + prompt_token_ids_cpu_tensor = torch.empty( + (self.num_reqs, max_prompt_len), + device="cpu", + dtype=torch.int64, + pin_memory=self.pin_memory, + ) + prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() + prompt_token_ids[:] = self.token_ids_cpu[:self. + num_reqs, :max_prompt_len] + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + for i in range(self.num_reqs): + prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, + non_blocking=True) + + def make_lora_inputs( + self, num_scheduled_tokens: np.ndarray + ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: + """ + Given the num_scheduled_tokens for each request in the batch, return + datastructures used to activate the current LoRAs. + Returns: + 1. prompt_lora_mapping: A tuple of size self.num_reqs where, + prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. + 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) + where, token_lora_mapping[i] is the LoRA id to use for ith token. + 3. lora_requests: Set of relevant LoRA requests. 
+ """ + + req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + prompt_lora_mapping = tuple(req_lora_mapping) + token_lora_mapping = tuple( + req_lora_mapping.repeat(num_scheduled_tokens)) + active_lora_requests: set[LoRARequest] = set( + self.lora_id_to_lora_request.values()) + + return prompt_lora_mapping, token_lora_mapping, active_lora_requests + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def no_penalties(self) -> bool: + return (len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0) + + @property + def max_num_logprobs(self) -> Optional[int]: + return max(self.num_logprobs.values()) if self.num_logprobs else None + + @property + def no_prompt_logprob(self) -> bool: + return not self.num_prompt_logprobs + + @property + def no_allowed_token_ids(self) -> bool: + return len(self.has_allowed_token_ids) == 0 diff --git a/vllm_npu/worker/worker_v1.py b/vllm_npu/worker/worker_v1.py index 772e524..0281488 100644 --- a/vllm_npu/worker/worker_v1.py +++ b/vllm_npu/worker/worker_v1.py @@ -1,229 +1,369 @@ -""" -NPUWorker — Ascend NPU worker for vLLM v1. +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py +# -Extends the GPU Worker to run on Ascend NPU devices, replacing CUDA -APIs with ``torch.npu`` / ``torch_npu`` equivalents for device -management, memory profiling, and distributed initialization. 
-""" - -import gc -import os -from typing import TYPE_CHECKING, Any, Optional +import copy +from typing import Optional, Union import torch - +import torch.nn as nn +import torch_npu +import vllm.envs as envs_vllm +from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions +from torch_npu.profiler import dynamic_profile as dp from vllm.config import VllmConfig -from vllm.distributed import ( - ensure_model_parallel_initialized, - init_distributed_environment, -) -from vllm.logger import init_logger +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.logger import logger from vllm.lora.request import LoRARequest -from vllm.platforms import current_platform -from vllm.utils import GiB_bytes, STR_DTYPE_TO_TORCH_DTYPE +from vllm.sequence import IntermediateTensors +from vllm.tasks import SupportedTask +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, ModelRunnerOutput) from vllm.v1.worker.worker_base import WorkerBase -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput +import vllm_npu.envs as envs_ascend +from vllm_npu.ascend_config import get_ascend_config, init_ascend_config +from vllm_npu.cpu_binding import bind_cpus +from vllm_npu.device_allocator.camem import CaMemAllocator +from vllm_npu.distributed.parallel_state import init_ascend_model_parallel +from vllm_npu.platform import NPUPlatform +from vllm_npu.utils import (init_ascend_soc_version, is_enable_nz, + register_ascend_customop, sleep_mode_enabled, + try_register_lib) +from vllm_npu.worker.model_runner_v1 import NPUModelRunner -logger = init_logger(__name__) +torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402 +from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402 + +torch_non_c_binding_in_graph_functions_npu = dict.fromkeys( + ["torch.npu.current_stream"], + TorchInGraphFunctionVariable, +) # noqa: E402 +torch_non_c_binding_in_graph_functions_npu[ + "torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402 +torch._dynamo.trace_rules.torch_name_rule_map.append( + torch_non_c_binding_in_graph_functions_npu) # noqa: E402 class NPUWorker(WorkerBase): - """Worker running on Ascend NPU devices.""" def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - **kwargs, - ): - super().__init__( - vllm_config=vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + # Additional parameters for compatibility with vllm + **kwargs): + """Initialize the worker for Ascend.""" + # register patch for vllm + from vllm_npu.utils import adapt_patch + adapt_patch() + is_enable_nz(vllm_config=vllm_config) + # Register ops when worker init. 
+ from vllm_npu import ops + ops.register_dummy_fusion_op() + _register_atb_extensions() + register_ascend_customop(vllm_config) + # init ascend config and soc version + init_ascend_config(vllm_config) + init_ascend_soc_version() + use_sparse = False + if vllm_config.model_config is not None: + use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") + if use_sparse: + # Direct import instead of using try_register_lib to ensure proper error handling when + # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments) + # yapf: disable + import custom_ops # type: ignore # noqa + + # yapf: enable + logger.info( + "custom_ops module loaded successfully. Custom operators like " + "torch.ops.custom.npu_sparse_flash_attention are now available." + ) + + super().__init__(vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker) + + # binding cpu + if get_ascend_config().enable_cpu_binding: + try: + bind_cpus(self.local_rank, ratio=1.0) + except RuntimeError as e: + logger.error(f"{e} in {self.local_rank}") + except ValueError as e: + logger.error(f"{e} in {self.local_rank}") + except Exception: + logger.info("Skip binding cpu.") + + # Try to import mindie_turbo to accelerate vLLM inference. + try_register_lib( + "mindie_turbo", + "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo." ) - - if self.model_config.trust_remote_code: - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Determine cache dtype if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else: self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - self.cache_config.cache_dtype - ] + self.cache_config.cache_dtype] - self.profiler = None - self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() - # ----------------------------------------------------------------- - # Device initialization - # ----------------------------------------------------------------- + self.profiler = self._init_profiler() + if sleep_mode_enabled(): + # Buffers saved before sleep + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} - def init_device(self) -> None: - """Initialize the NPU device and distributed environment.""" - import torch_npu # noqa: F401 - - os.environ.pop("HCCL_ASYNC_ERROR_HANDLING", None) - - self.device = torch.device(f"npu:{self.local_rank}") - current_platform.set_device(self.device) - current_platform.empty_cache() - - # Record initial memory - self.init_npu_memory, self.total_npu_memory = ( - current_platform.mem_get_info() - ) - - # Initialize distributed (HCCL) - init_distributed_environment( - world_size=self.parallel_config.world_size, - rank=self.rank, - distributed_init_method=self.distributed_init_method, - local_rank=self.local_rank, - backend="hccl", - ) - - # Initialize TP / PP parallel groups - ensure_model_parallel_initialized( - tensor_model_parallel_size=( - self.parallel_config.tensor_parallel_size), - pipeline_model_parallel_size=( - self.parallel_config.pipeline_parallel_size), - ) - - # Set random seed - current_platform.seed_everything(self.model_config.seed) - - # NPU memory snapshot - self.requested_memory = ( - self.total_npu_memory * self.cache_config.gpu_memory_utilization - ) - - # Construct model runner - 
self.model_runner: GPUModelRunner = GPUModelRunner( - self.vllm_config, self.device - ) - - # ----------------------------------------------------------------- - # Memory profiling - # ----------------------------------------------------------------- - - @torch.inference_mode() - def determine_available_memory(self) -> int: - """Profile peak memory and return available KV cache memory.""" - import torch_npu # noqa: F401 - - GiB = lambda b: round(b / GiB_bytes, 2) - - current_platform.empty_cache() - gc.collect() - - # Execute a forward pass with dummy inputs to profile memory - self.model_runner.profile_run() - - # Check peak memory - free_npu_memory, _ = current_platform.mem_get_info() - - assert self.init_npu_memory > free_npu_memory, ( - "Error in memory profiling. " - f"Initial free memory {GiB(self.init_npu_memory)} GiB, " - f"current free memory {GiB(free_npu_memory)} GiB." - ) - - # Get peak memory from torch_npu stats - peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"] - - current_platform.empty_cache() - torch_allocated = torch_npu.npu.memory_stats()[ - "allocated_bytes.all.current" - ] - total_allocated = ( - self.total_npu_memory - torch_npu.npu.mem_get_info()[0] - ) - non_torch = total_allocated - torch_allocated - if non_torch > 0: - peak_memory += non_torch - - available_kv_cache_memory = int( - self.total_npu_memory * self.cache_config.gpu_memory_utilization - - peak_memory - ) - available_kv_cache_memory = max(available_kv_cache_memory, 0) + # FixMe: this is a patch to fix the issue cause by https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170 + from vllm.model_executor.layers.linear import \ + WEIGHT_LOADER_V2_SUPPORTED + if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: + WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") + def sleep(self, level: int = 1) -> None: + if not sleep_mode_enabled(): + raise ValueError( + "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1." + ) + free_bytes_before_sleep = NPUPlatform.mem_get_info()[0] + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() + for name, buffer in model.named_buffers() + } + allocator = CaMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = NPUPlatform.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." logger.info( - "Available KV cache memory: %.2f GiB (total: %.2f GiB)", - GiB(available_kv_cache_memory), - GiB(self.total_npu_memory), - ) + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) - gc.collect() - return available_kv_cache_memory + def wake_up(self, tags: Optional[list[str]] = None) -> None: + if not sleep_mode_enabled(): + raise ValueError( + "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1." + ) - # ----------------------------------------------------------------- - # Model lifecycle - # ----------------------------------------------------------------- + if is_enable_nz(): + raise ValueError( + "FRACTAL_NZ mode is enabled. This may cause model parameter precision issues " + "in the RL scenarios. 
Please set vllm_npu_ENABLE_NZ=0.") + allocator = CaMemAllocator.get_instance() + allocator.wake_up(tags=tags) - def load_model(self) -> None: - self.model_runner.load_model() - - def get_model(self): - return self.model_runner.get_model() - - def get_kv_cache_spec(self) -> KVCacheSpec: - return self.model_runner.get_kv_cache_spec() + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """Store the number of KV cache blocks.""" self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: - """Allocate KV caches on NPU.""" - self.model_runner.initialize_kv_cache(kv_cache_config) + def _init_device(self): + device = torch.device(f"npu:{self.local_rank}") + NPUPlatform.set_device(device) + NPUPlatform.empty_cache() + self.init_npu_memory = NPUPlatform.mem_get_info()[0] + # Initialize the distributed environment. + self._init_worker_distributed_environment() + # Set random seed. + NPUPlatform.seed_everything(self.model_config.seed) + return device - def compile_or_warm_up_model(self) -> None: - """Warm up the model (no torch.compile on NPU).""" - self.model_runner.capture_model() + def init_device(self): + device = self._init_device() + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = NPUModelRunner(self.vllm_config, device) - # ----------------------------------------------------------------- - # Execution - # ----------------------------------------------------------------- + def determine_available_memory(self) -> int: + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + NPUPlatform.clear_npu_memory() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + _, total_npu_memory = NPUPlatform.mem_get_info() + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + free_npu_memory, _ = NPUPlatform.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + assert self.init_npu_memory > free_npu_memory, ( + "Error in memory profiling. " + f"Initial free memory {self.init_npu_memory}, current free memory" + f" {free_npu_memory}. 
This happens when the NPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
+        # TODO: this empty_cache() call can be removed once the empty_cache in
+        # Worker.determine_num_available_blocks() is unified upstream.
+        NPUPlatform.empty_cache()
+        torch_allocated_bytes = torch_npu.npu.memory_stats(
+        )["allocated_bytes.all.current"]
+        total_allocated_bytes = torch_npu.npu.mem_get_info(
+        )[1] - torch_npu.npu.mem_get_info()[0]
+        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+        available_kv_cache_memory = int(
+            total_npu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
+        available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
+        logger.info(
+            f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
+        )
+        return available_kv_cache_memory

     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> Optional[ModelRunnerOutput]:
-        output = self.model_runner.execute_model(scheduler_output)
-        return output if self.is_driver_worker else None
+    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
+        # enable msMonitor to monitor the performance of vllm-ascend
+        if envs_ascend.MSMONITOR_USE_DAEMON:
+            dp.step()

-    def execute_dummy_batch(self) -> None:
-        self.model_runner.execute_dummy_batch()
+        intermediate_tensors = None
+        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+        if forward_pass and not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))

-    # -----------------------------------------------------------------
-    # Misc
-    # -----------------------------------------------------------------
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
+            return output

-    def sleep(self, level: int = 1) -> None:
-        pass
+        assert isinstance(output, IntermediateTensors)
+        parallel_config = self.vllm_config.parallel_config
+        assert parallel_config.distributed_executor_backend != (
+            "external_launcher") and not get_pp_group().is_last_rank

-    def wake_up(self, tags: Optional[list[str]] = None) -> None:
-        pass
+        get_pp_group().send_tensor_dict(output.tensors,
+                                        all_gather_group=get_tp_group())

-    def get_supported_tasks(self):
-        return self.model_runner.get_supported_tasks()
+        kv_connector_output = output.kv_connector_output
+        if not kv_connector_output:
+            return None
+
+        # In case of PP with kv transfer, we need to pass through the
+        # kv_connector_output
+        if (not kv_connector_output.finished_sending
+                and not kv_connector_output.finished_recving):
+            return EMPTY_MODEL_RUNNER_OUTPUT
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.kv_connector_output = kv_connector_output
+        return output
+
+    def load_model(self) -> None:
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CaMemAllocator.get_instance()
+            assert allocator.get_current_usage() == 0, (
+                "Sleep mode can only be "
+                "used for one instance per process.")
+            context = allocator.use_memory_pool(tag="weights")
+        else:
+            from contextlib import nullcontext
+            context = nullcontext()  # type: ignore
+        with context:
+            self.model_runner.load_model()
+
+    def compile_or_warm_up_model(self) -> None:
+        # Note: need to adapt for graph mode.
+ self.model_runner.eplb_warmup() + warmup_sizes = (self.vllm_config.compilation_config.compile_sizes + or []).copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size) + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache) + # may cause performance degradation at runtime. + self._warm_up_atb() + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + NPUPlatform.seed_everything(self.model_config.seed) + + def _warm_up_atb(self): + x = torch.rand((2, 4), dtype=torch.float16).npu() + weight = torch.rand((2, 4), dtype=torch.float16).npu() + c = torch.rand((4, 4), dtype=torch.float32).npu() + torch_npu._npu_matmul_add_fp32(x, weight, c) + + def get_model(self) -> nn.Module: + return self.model_runner.get_model() + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + return self.model_runner.get_kv_cache_spec() + + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate NPU KV cache with the specified kv_cache_config.""" + if self.vllm_config.model_config.enable_sleep_mode: + allocator = CaMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + from contextlib import nullcontext + context = nullcontext() # type: ignore + with context: + self.model_runner.initialize_kv_cache(kv_cache_config) + + def profile(self, is_start: bool = True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) @@ -231,17 +371,72 @@ class NPUWorker(WorkerBase): def remove_lora(self, lora_id: int) -> bool: return self.model_runner.remove_lora(lora_id) - def list_loras(self) -> set: + def list_loras(self) -> set[int]: return self.model_runner.list_loras() def pin_lora(self, lora_id: int) -> bool: return self.model_runner.pin_lora(lora_id) - def profile(self, is_start: bool = True) -> None: - pass + def execute_dummy_batch(self) -> None: + self.model_runner._dummy_run( + num_tokens=self.model_runner.decode_token_per_req, + uniform_decode=True) - def take_draft_token_ids(self): + def _init_worker_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + init_distributed_environment(self.parallel_config.world_size, + self.rank, self.distributed_init_method, + self.local_rank, "hccl") + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + init_ascend_model_parallel(self.parallel_config) + ensure_kv_transfer_initialized(self.vllm_config) + + def _init_profiler(self): + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs_vllm.VLLM_TORCH_PROFILER_DIR: + if envs_ascend.MSMONITOR_USE_DAEMON: + raise RuntimeError( + "MSMONITOR_USE_DAEMON and VLLM_TORCH_PROFILER_DIR cannot be both set at the same time." + ) + torch_profiler_trace_dir = envs_vllm.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. 
Traces will be saved to: %s", + torch_profiler_trace_dir) + + experimental_config = torch_npu.profiler._ExperimentalConfig( + export_type=torch_npu.profiler.ExportType.Text, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + msprof_tx=False, + aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone, + l2_cache=False, + op_attr=False, + data_simplification=False, + record_op_args=False, + gc_detect_threshold=None, + ) + + return torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU, + ], + with_stack=envs_vllm.VLLM_TORCH_PROFILER_WITH_STACK, + profile_memory=envs_vllm.\ + VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, + with_modules=False, + experimental_config=experimental_config, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir)) + else: + return None + + def get_supported_pooling_tasks(self): + return self.model_runner.get_supported_pooling_tasks() + + def get_supported_tasks(self) -> "tuple[SupportedTask, ...]": + return self.model_runner.get_supported_tasks() + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: return self.model_runner.take_draft_token_ids() - - def check_health(self) -> None: - pass
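
For reference, the padding scheme used by _make_prompt_token_ids_tensor() above can be reproduced in plain numpy: prompts are packed into one rectangular array and the tail of each row is filled with vocab_size, a value no real token id can take. This is a minimal standalone sketch with illustrative values only, not code from the patch.

import numpy as np

vocab_size = 8                      # illustrative; the real value comes from the model config
prompts = [[1, 2, 3], [4, 5], [6]]  # hypothetical per-request prompt token ids

max_len = max(len(p) for p in prompts)
# Pad with vocab_size, since no real token uses that id.
padded = np.full((len(prompts), max_len), vocab_size, dtype=np.int64)
for i, p in enumerate(prompts):
    padded[i, :len(p)] = p

print(padded)
# [[1 2 3]
#  [4 5 8]
#  [6 8 8]]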
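Similarly, the per-token LoRA mapping produced by make_lora_inputs() boils down to an elementwise repeat: each request's LoRA id is repeated once per token scheduled for that request. The sketch below uses illustrative values and casts to int only so the printed output is readable.

import numpy as np

request_lora_mapping = np.array([0, 2, 1])  # LoRA id per request (0 = no LoRA)
num_scheduled_tokens = np.array([3, 1, 2])  # tokens scheduled per request

prompt_lora_mapping = tuple(int(x) for x in request_lora_mapping)
token_lora_mapping = tuple(
    int(x) for x in request_lora_mapping.repeat(num_scheduled_tokens))

print(prompt_lora_mapping)  # (0, 2, 1)
print(token_lora_mapping)   # (0, 0, 0, 2, 1, 1)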
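The level-2 sleep path above snapshots model buffers before the weights are discarded and copies them back after wake_up(). The pattern itself is ordinary PyTorch and can be sketched without the Ascend-specific CaMemAllocator; the module below is a stand-in for the real model.

import torch.nn as nn

model = nn.BatchNorm1d(4)  # any module with buffers (running stats) serves as a stand-in

# Before sleeping at level 2: snapshot buffers on the host.
saved = {name: buf.cpu().clone() for name, buf in model.named_buffers()}

# ... weights and buffers are released while the instance sleeps ...

# After waking up: restore the snapshot into the (re-allocated) buffers in place.
for name, buf in model.named_buffers():
    if name in saved:
        buf.data.copy_(saved[name].data)
saved.clear()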
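The budget computed in determine_available_memory() is a small piece of arithmetic: memory held outside the torch allocator is added to the profiled peak, and the KV cache gets whatever remains of the gpu_memory_utilization share. A pure-Python sketch with illustrative numbers (the real inputs come from NPUPlatform.mem_get_info() and torch_npu.npu.memory_stats()):

GiB = 1024 ** 3

def available_kv_cache_memory(total_npu_memory: int,
                              peak_torch_memory: int,
                              torch_allocated_now: int,
                              total_allocated_now: int,
                              gpu_memory_utilization: float) -> int:
    # Memory held outside the torch allocator (e.g. collective-comm buffers,
    # ATB workspaces) is charged against the peak as well.
    non_torch = total_allocated_now - torch_allocated_now
    if non_torch > 0:
        peak_torch_memory += non_torch
    budget = int(total_npu_memory * gpu_memory_utilization - peak_torch_memory)
    return max(budget, 0)

# Example: 64 GiB device, 90% utilization, 20 GiB torch peak, 2 GiB non-torch overhead.
print(available_kv_cache_memory(64 * GiB, 20 * GiB, 18 * GiB, 20 * GiB, 0.9) / GiB)
# -> roughly 35.6 GiB left for the KV cache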
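Finally, the warm-up size selection in compile_or_warm_up_model() filters out sizes that will be captured as graphs anyway and warms up the rest from largest to smallest. A minimal sketch of that selection logic, with hypothetical sizes and a print standing in for _dummy_run():

compile_sizes = [1, 2, 4, 8, 16]          # illustrative compilation_config.compile_sizes
cudagraph_capture_sizes = [1, 2, 4]       # illustrative capture sizes

# Sizes already covered by graph capture do not need a separate eager warm-up.
warmup_sizes = [s for s in compile_sizes if s not in cudagraph_capture_sizes]
for size in sorted(warmup_sizes, reverse=True):
    print(f"warming up dummy run for batch size {size}")
# warming up dummy run for batch size 16
# warming up dummy run for batch size 8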