"""
|
|
NPUPlatform — Ascend NPU platform implementation for vLLM.
|
|
|
|
Implements the ``vllm.platforms.Platform`` interface so that vLLM can
|
|
transparently target Huawei Ascend NPU devices.
|
|
"""

import gc
import os
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Tuple

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore

from vllm.logger import init_logger
from vllm.platforms import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.utils import FlexibleArgumentParser
else:
    ModelConfig = None
    VllmConfig = None
    FlexibleArgumentParser = None

logger = init_logger(__name__)


class NPUPlatform(Platform):
    """Out-of-tree platform for Huawei Ascend NPU."""

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    dispatch_key: str = "PrivateUse1"
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    simple_compile_backend: str = "eager"  # torch.compile not supported on NPU
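
    # Usage note (hedged): ``device_control_env_var`` plays the same role that
    # CUDA_VISIBLE_DEVICES plays for GPUs; vLLM and Ray use it to restrict which
    # NPUs a worker may see. The device indices below are placeholder values.
    #
    #   import os
    #   os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"  # expose only NPUs 0 and 1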

    # -----------------------------------------------------------------
    # Device management
    # -----------------------------------------------------------------

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        # Ascend NPUs do not expose a CUDA-style compute capability.
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        import torch_npu  # noqa: F401

        return torch.npu.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        import torch_npu  # noqa: F401

        _, total = torch.npu.mem_get_info(device_id)
        return total

    @classmethod
    def inference_mode(cls):
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        import torch_npu  # noqa: F401

        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        import torch_npu  # noqa: F401

        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        import torch_npu  # noqa: F401

        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        import torch_npu  # noqa: F401

        # Returns (free, total) device memory in bytes, mirroring
        # torch.cuda.mem_get_info().
        return torch.npu.mem_get_info()

    @classmethod
    def is_pin_memory_available(cls):
        return True

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return True

    @classmethod
    def get_current_memory_usage(
        cls,
        device: Optional[torch.types.Device] = None,
    ) -> float:
        import torch_npu  # noqa: F401

        # Reset the peak counter so the returned value reflects the memory
        # currently allocated on the device.
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def clear_npu_memory(cls):
        import torch_npu  # noqa: F401

        gc.collect()
        torch.npu.empty_cache()
        torch.npu.reset_peak_memory_stats()

    def is_sleep_mode_available(self) -> bool:
        return False

    # -----------------------------------------------------------------
    # Attention backend routing
    # -----------------------------------------------------------------

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink=False,
        use_sparse=False,
    ):
        # All configurations route to the Ascend attention backend.
        return "vllm_npu.attention.attention_v1.AscendAttentionBackend"
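
    # Resolution sketch (hedged): vLLM treats the returned value as a dotted
    # import path and loads the class roughly as below; the helper name is
    # illustrative, not an exact vLLM internal.
    #
    #   import importlib
    #
    #   def _load_backend(path: str):
    #       module_name, _, class_name = path.rpartition(".")
    #       return getattr(importlib.import_module(module_name), class_name)
    #
    #   backend_cls = _load_backend(
    #       "vllm_npu.attention.attention_v1.AscendAttentionBackend")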

    # -----------------------------------------------------------------
    # Distributed
    # -----------------------------------------------------------------

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm_npu.distributed.communicator.NPUCommunicator"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Create an HCCL-backed process group for stateless NPU distributed init."""
        from torch.distributed import is_hccl_available
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL

        assert is_hccl_available(), (
            "HCCL is not available. Make sure torch_npu is properly installed."
        )

        # Build a ProcessGroup shell around the store, then attach an HCCL
        # backend registered for the NPU device.
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )

        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupHCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        device = torch.device("npu")
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM

        pg._register_backend(device, backend_type, backend_class)
        return pg
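
    # Usage sketch (hedged): callers are expected to supply a PrefixStore built
    # on top of a rendezvous store plus rank/world-size and a timeout. The
    # address, port, world size, and rank below are placeholder values.
    #
    #   from datetime import timedelta
    #   from torch.distributed import TCPStore
    #   from torch.distributed.distributed_c10d import PrefixStore
    #
    #   rank = 0  # this process's rank (placeholder)
    #   store = TCPStore("127.0.0.1", 29500, 2, is_master=(rank == 0))
    #   pg = NPUPlatform.stateless_init_device_torch_dist_pg(
    #       backend="hccl",
    #       prefix_store=PrefixStore("npu_pg", store),
    #       group_rank=rank,
    #       group_size=2,
    #       timeout=timedelta(minutes=5),
    #   )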

    # -----------------------------------------------------------------
    # Configuration
    # -----------------------------------------------------------------

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adapt the vLLM configuration for NPU hardware."""
        from vllm.config import CompilationLevel

        # Register NPU custom ops (must happen after the platform is detected).
        from vllm_npu import register_npu_ops

        register_npu_ops()

        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config

        # Set the worker class.
        if parallel_config and parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = (
                "vllm_npu.worker.worker_v1.NPUWorker"
            )

        # Set the default block size for NPU (aligned to 128).
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128

        # Disable torch.compile on NPU; use eager mode.
        compilation_config.level = CompilationLevel.NO_COMPILATION

        logger.info(
            "NPUPlatform: configuration updated: "
            "worker_cls=%s, block_size=%s, compilation=NO_COMPILATION",
            getattr(parallel_config, "worker_cls", "N/A"),
            getattr(cache_config, "block_size", "N/A"),
        )
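
    # Effect sketch: with default settings, the call above leaves the config
    # roughly as follows (values taken from the assignments in this method):
    #
    #   vllm_config.parallel_config.worker_cls  == "vllm_npu.worker.worker_v1.NPUWorker"
    #   vllm_config.cache_config.block_size     == 128
    #   vllm_config.compilation_config.level    == CompilationLevel.NO_COMPILATION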

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return False

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return False