mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
feat: initial vllm-npu-plugin for Ascend NPU adaptation
- NPUPlatform: device management, HCCL process group, config adaptation
- AscendAttentionBackend: npu_fusion_attention (prefill) + npu_incre_flash_attention (decode)
- NPUCommunicator: HCCL-based distributed communication
- NPUWorker: NPU device init, memory profiling
- Custom ops: SiluAndMul, RMS norm, rotary embedding
- Plugin registered via the vllm.platform_plugins entry point

Based on the official vllm-ascend pattern, targeting Ascend 910B.
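
The vllm.platform_plugins hook mentioned above is typically a small register() function. A minimal sketch, assuming a vllm_npu/__init__.py layout and a torch_npu import guard (neither is shown in this diff):

    # vllm_npu/__init__.py (hypothetical module layout)
    # vLLM calls each function registered under the "vllm.platform_plugins"
    # entry-point group at startup; returning a dotted class path selects
    # this platform, returning None opts out.
    from typing import Optional

    def register() -> Optional[str]:
        try:
            import torch_npu  # noqa: F401  # probe for the Ascend stack
        except ImportError:
            return None  # no torch_npu: let another platform be chosen
        return "vllm_npu.platform.NPUPlatform"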
This commit is contained in:

vllm_npu/platform.py (new file, 217 lines)
@@ -0,0 +1,217 @@
"""
NPUPlatform: Ascend NPU platform implementation for vLLM.

Implements the ``vllm.platforms.Platform`` interface so that vLLM can
transparently target Huawei Ascend NPU devices.
"""

import gc
import os
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Tuple

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.logger import init_logger
from vllm.platforms import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.utils import FlexibleArgumentParser
else:
    ModelConfig = None
    VllmConfig = None
    FlexibleArgumentParser = None

logger = init_logger(__name__)


class NPUPlatform(Platform):
    """Out-of-tree platform for Huawei Ascend NPU."""

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    dispatch_key: str = "PrivateUse1"
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    simple_compile_backend: str = "eager"  # torch.compile not supported
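
    # "PrivateUse1" is the dispatch key PyTorch reserves for out-of-tree
    # device backends such as torch_npu. ASCEND_RT_VISIBLE_DEVICES is the
    # Ascend runtime's counterpart to CUDA_VISIBLE_DEVICES; vLLM and Ray
    # use it (via device_control_env_var) to scope NPUs to workers.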

    # -----------------------------------------------------------------
    # Device management
    # -----------------------------------------------------------------
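
    # torch_npu is imported lazily inside the methods below so this module
    # stays importable on hosts without the Ascend toolchain; the import
    # itself is what registers the torch.npu namespace these calls use.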

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        import torch_npu  # noqa: F401
        return torch.npu.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        import torch_npu  # noqa: F401
        _, total = torch.npu.mem_get_info(device_id)
        return total

    @classmethod
    def inference_mode(cls):
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        import torch_npu  # noqa: F401
        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        import torch_npu  # noqa: F401
        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        import torch_npu  # noqa: F401
        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        import torch_npu  # noqa: F401
        return torch.npu.mem_get_info()

    @classmethod
    def is_pin_memory_available(cls):
        return True

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return True

    @classmethod
    def get_current_memory_usage(
        cls,
        device: Optional[torch.types.Device] = None,
    ) -> float:
        import torch_npu  # noqa: F401
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def clear_npu_memory(cls):
        import torch_npu  # noqa: F401
        gc.collect()
        torch.npu.empty_cache()
        torch.npu.reset_peak_memory_stats()
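
    # The helpers above (mem_get_info, get_current_memory_usage,
    # clear_npu_memory) back the NPUWorker's memory profiling, which sizes
    # the KV cache the same way vLLM profiles CUDA devices.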

    def is_sleep_mode_available(self) -> bool:
        return False

    # -----------------------------------------------------------------
    # Attention backend routing
    # -----------------------------------------------------------------

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink=False,
        use_sparse=False,
    ):
        return "vllm_npu.attention.attention_v1.AscendAttentionBackend"
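
    # Every head size, dtype, and block size routes to the single Ascend
    # backend, which (per this commit) uses npu_fusion_attention for
    # prefill and npu_incre_flash_attention for decode.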

    # -----------------------------------------------------------------
    # Distributed
    # -----------------------------------------------------------------

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm_npu.distributed.communicator.NPUCommunicator"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Create an HCCL-based process group for NPU distributed."""
        from torch.distributed import is_hccl_available
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL

        assert is_hccl_available(), (
            "HCCL is not available. Make sure torch_npu is properly installed."
        )

        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )

        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupHCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        device = torch.device("npu")
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM

        pg._register_backend(device, backend_type, backend_class)
        return pg
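
    # This follows vLLM's stateless process-group pattern: build a bare
    # ProcessGroup shell over the store, then attach the HCCL backend
    # under BackendType.CUSTOM so collectives on "npu" tensors dispatch
    # to it without touching torch.distributed's global state.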

    # -----------------------------------------------------------------
    # Configuration
    # -----------------------------------------------------------------

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adapt vLLM configuration for NPU hardware."""
        from vllm.config import CompilationLevel

        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config

        # Set worker class
        if parallel_config and parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = (
                "vllm_npu.worker.worker_v1.NPUWorker"
            )

        # Set default block size for NPU (aligned to 128)
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128

        # Disable torch.compile on NPU; use eager mode
        compilation_config.level = CompilationLevel.NO_COMPILATION

        logger.info(
            "NPUPlatform: configuration updated "
            "(worker_cls=%s, block_size=%s, compilation=NO_COMPILATION)",
            getattr(parallel_config, "worker_cls", "N/A"),
            getattr(cache_config, "block_size", "N/A"),
        )

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return False

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return False
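
A quick post-install sanity check (illustrative, not part of this commit): once the entry point resolves, vLLM should report the NPU platform before any model loads.

    from vllm.platforms import current_platform

    # On an Ascend host with vllm-npu-plugin installed, this is expected
    # to print "npu"; on other hosts the plugin's register() returns None
    # and a different platform is selected.
    print(current_platform.device_type)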