vllm-npu-plugin/vllm_npu/platform.py
handsomezhuzhu e75504df72 feat: initial vllm-npu-plugin for Ascend NPU adaptation
- NPUPlatform: device management, HCCL process group, config adaptation
- AscendAttentionBackend: npu_fusion_attention (prefill) + npu_incre_flash_attention (decode)
- NPUCommunicator: HCCL-based distributed communication
- NPUWorker: NPU device init, memory profiling
- Custom ops: SiluAndMul, RMS norm, rotary embedding
- Plugin registered via vllm.platform_plugins entry point

Based on the official vllm-ascend pattern, targeting Ascend 910B
2026-02-10 11:06:01 +08:00

218 lines
6.4 KiB
Python
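
The commit message notes that the plugin is registered through the vllm.platform_plugins entry point. Below is a minimal sketch of what that registration typically looks like, assuming setuptools packaging; the package name, module layout, and register() helper are illustrative assumptions, not code taken from this repository.

# setup.py (hypothetical): advertise the plugin to vLLM's plugin loader.
from setuptools import find_packages, setup

setup(
    name="vllm-npu-plugin",
    packages=find_packages(),
    entry_points={
        "vllm.platform_plugins": [
            # vLLM imports and calls this function during platform detection.
            "npu = vllm_npu:register",
        ],
    },
)

# vllm_npu/__init__.py (hypothetical)
def register() -> str:
    # vLLM expects the entry point to return the dotted path of the
    # Platform subclass (or None to opt out on unsupported hosts).
    return "vllm_npu.platform.NPUPlatform"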

"""
NPUPlatform — Ascend NPU platform implementation for vLLM.
Implements the ``vllm.platforms.Platform`` interface so that vLLM can
transparently target Huawei Ascend NPU devices.
"""
import gc
import os
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Tuple

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore

from vllm.logger import init_logger
from vllm.platforms import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.utils import FlexibleArgumentParser
else:
    ModelConfig = None
    VllmConfig = None
    FlexibleArgumentParser = None

logger = init_logger(__name__)


class NPUPlatform(Platform):
    """Out-of-tree platform for Huawei Ascend NPU."""

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    dispatch_key: str = "PrivateUse1"
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    simple_compile_backend: str = "eager"  # torch.compile not supported
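
    # NOTE: torch_npu exposes the Ascend device to PyTorch through the
    # PrivateUse1 dispatch key, which is why ``dispatch_key`` is set above.
    # ``torch_npu`` is imported lazily inside the methods below, so importing
    # this module does not by itself require an Ascend runtime.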
    # -----------------------------------------------------------------
    # Device management
    # -----------------------------------------------------------------
    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        import torch_npu  # noqa: F401
        return torch.npu.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        import torch_npu  # noqa: F401
        _, total = torch.npu.mem_get_info(device_id)
        return total

    @classmethod
    def inference_mode(cls):
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        import torch_npu  # noqa: F401
        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        import torch_npu  # noqa: F401
        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        import torch_npu  # noqa: F401
        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        import torch_npu  # noqa: F401
        return torch.npu.mem_get_info()

    @classmethod
    def is_pin_memory_available(cls):
        return True

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return True

    @classmethod
    def get_current_memory_usage(
        cls,
        device: Optional[torch.types.Device] = None,
    ) -> float:
        import torch_npu  # noqa: F401
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def clear_npu_memory(cls):
        import torch_npu  # noqa: F401
        gc.collect()
        torch.npu.empty_cache()
        torch.npu.reset_peak_memory_stats()

    def is_sleep_mode_available(self) -> bool:
        return False
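
    # Per the commit description, AscendAttentionBackend (selected below) uses
    # npu_fusion_attention for prefill and npu_incre_flash_attention for
    # decode.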
    # -----------------------------------------------------------------
    # Attention backend routing
    # -----------------------------------------------------------------
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink=False,
        use_sparse=False,
    ):
        return "vllm_npu.attention.attention_v1.AscendAttentionBackend"
    # -----------------------------------------------------------------
    # Distributed
    # -----------------------------------------------------------------
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm_npu.distributed.communicator.NPUCommunicator"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Create an HCCL-based process group for NPU distributed."""
        from torch.distributed import is_hccl_available
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL

        assert is_hccl_available(), (
            "HCCL is not available. Make sure torch_npu is properly installed."
        )
        # Bare ProcessGroup shell; concrete backends are registered onto it.
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        # Configure the HCCL backend with the requested timeout.
        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout
        backend_class = ProcessGroupHCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        device = torch.device("npu")
        backend_class._set_sequence_number_for_group()
        # Register the HCCL backend for "npu" devices under the CUSTOM backend
        # type so collectives on NPU tensors are dispatched to it.
        backend_type = ProcessGroup.BackendType.CUSTOM
        pg._register_backend(device, backend_type, backend_class)
        return pg

    # -----------------------------------------------------------------
    # Configuration
    # -----------------------------------------------------------------
    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adapt vLLM configuration for NPU hardware."""
        from vllm.config import CompilationLevel

        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config

        # Set worker class
        if parallel_config and parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = (
                "vllm_npu.worker.worker_v1.NPUWorker"
            )

        # Set default block size for NPU (aligned to 128)
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128

        # Disable torch.compile on NPU — use eager mode
        compilation_config.level = CompilationLevel.NO_COMPILATION

        logger.info(
            "NPUPlatform: configuration updated — "
            "worker_cls=%s, block_size=%s, compilation=NO_COMPILATION",
            getattr(parallel_config, "worker_cls", "N/A"),
            getattr(cache_config, "block_size", "N/A"),
        )

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return False

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return False
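
# Usage sketch (hypothetical): with the plugin package installed, vLLM
# resolves NPUPlatform through the ``vllm.platform_plugins`` entry point, so
# a standard vLLM invocation runs on the NPU without extra flags, e.g.:
#
#     from vllm import LLM
#     llm = LLM(model="<model name>", tensor_parallel_size=2)
#     outputs = llm.generate(["Hello from Ascend 910B"])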