vllm-npu-plugin/vllm_npu/worker/worker_v1.py
handsomezhuzhu e75504df72 feat: initial vllm-npu-plugin for Ascend NPU adaptation
- NPUPlatform: device management, HCCL process group, config adaptation
- AscendAttentionBackend: npu_fusion_attention (prefill) + npu_incre_flash_attention (decode)
- NPUCommunicator: HCCL-based distributed communication
- NPUWorker: NPU device init, memory profiling
- Custom ops: SiluAndMul, RMS norm, rotary embedding
- Plugin registered via vllm.platform_plugins entry point

Based on the official vllm-ascend pattern, targeting Ascend 910B
2026-02-10 11:06:01 +08:00
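The plugin-registration bullet above refers to vLLM's out-of-tree platform mechanism: an entry point in the vllm.platform_plugins group that resolves to a callable returning the import path of the platform class (or None when the hardware is unavailable). A minimal sketch of what that registration could look like for this plugin; the register() helper and the vllm_npu.platform module path are illustrative assumptions, not taken from the commit:

# vllm_npu/__init__.py (illustrative sketch)
#
# Exposed from pyproject.toml roughly as:
#   [project.entry-points."vllm.platform_plugins"]
#   npu = "vllm_npu:register"
from typing import Optional


def register() -> Optional[str]:
    """Called by vLLM at startup: return the platform class path, or None."""
    try:
        import torch_npu  # noqa: F401  (only activate when the Ascend runtime is importable)
    except ImportError:
        return None
    return "vllm_npu.platform.NPUPlatform"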


"""
NPUWorker — Ascend NPU worker for vLLM v1.
Extends the GPU Worker to run on Ascend NPU devices, replacing CUDA
APIs with ``torch.npu`` / ``torch_npu`` equivalents for device
management, memory profiling, and distributed initialization.
"""
import gc
import os
from typing import TYPE_CHECKING, Optional
import torch
from vllm.config import VllmConfig
from vllm.distributed import init_distributed_environment
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)


class NPUWorker(WorkerBase):
    """Worker running on Ascend NPU devices."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        **kwargs,
    ):
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        if self.model_config.trust_remote_code:
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Determine cache dtype
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype
            ]

        self.profiler = None
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

    # -----------------------------------------------------------------
    # Device initialization
    # -----------------------------------------------------------------
    def init_device(self) -> None:
        """Initialize the NPU device and distributed environment."""
        import torch_npu  # noqa: F401
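        # Drop any externally injected HCCL_ASYNC_ERROR_HANDLING setting and
        # let HCCL use its defaults (presumably mirroring the GPU worker,
        # which pops NCCL_ASYNC_ERROR_HANDLING because values injected by
        # launchers such as Ray can cause problems).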
        os.environ.pop("HCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"npu:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.empty_cache()

        # Record initial memory
        self.init_npu_memory, self.total_npu_memory = (
            current_platform.mem_get_info()
        )
        # Initialize distributed (HCCL)
        init_distributed_environment(
            self.parallel_config.world_size,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            "hccl",
        )
        # Set random seed
        current_platform.seed_everything(self.model_config.seed)

        # Memory budget requested for this worker
        self.requested_memory = (
            self.total_npu_memory * self.cache_config.gpu_memory_utilization
        )

        # Construct model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )

    # -----------------------------------------------------------------
    # Memory profiling
    # -----------------------------------------------------------------
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile peak memory and return available KV cache memory."""
        import torch_npu  # noqa: F401

        GiB = lambda b: round(b / GiB_bytes, 2)

        current_platform.empty_cache()
        gc.collect()

        # Execute a forward pass with dummy inputs to profile memory
        self.model_runner.profile_run()

        # Check peak memory
        free_npu_memory, _ = current_platform.mem_get_info()
        assert self.init_npu_memory > free_npu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_npu_memory)} GiB, "
            f"current free memory {GiB(free_npu_memory)} GiB."
        )

        # Get peak memory from torch_npu stats
        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
        current_platform.empty_cache()
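        # Memory held outside the torch_npu caching allocator (for example
        # HCCL buffers or CANN workspaces) shows up as the gap between total
        # device usage and allocator-tracked bytes; fold it into the peak.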
        torch_allocated = torch_npu.npu.memory_stats()[
            "allocated_bytes.all.current"
        ]
        total_allocated = (
            self.total_npu_memory - torch_npu.npu.mem_get_info()[0]
        )
        non_torch = total_allocated - torch_allocated
        if non_torch > 0:
            peak_memory += non_torch
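        # Illustrative arithmetic (assumed numbers, not a measurement): on a
        # 64 GiB device with gpu_memory_utilization=0.9 and a 20 GiB profiled
        # peak, this leaves about 64 * 0.9 - 20 = 37.6 GiB for the KV cache.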
        available_kv_cache_memory = int(
            self.total_npu_memory * self.cache_config.gpu_memory_utilization
            - peak_memory
        )
        available_kv_cache_memory = max(available_kv_cache_memory, 0)

        logger.info(
            "Available KV cache memory: %.2f GiB (total: %.2f GiB)",
            GiB(available_kv_cache_memory),
            GiB(self.total_npu_memory),
        )
        gc.collect()
        return available_kv_cache_memory

    # -----------------------------------------------------------------
    # Model lifecycle
    # -----------------------------------------------------------------
    def load_model(self) -> None:
        self.model_runner.load_model()

    def get_model(self):
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate KV caches on NPU."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model (no torch.compile on NPU)."""
        self.model_runner.capture_model()

    # -----------------------------------------------------------------
    # Execution
    # -----------------------------------------------------------------
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.is_driver_worker else None

    def execute_dummy_batch(self) -> None:
        self.model_runner.execute_dummy_batch()

    # -----------------------------------------------------------------
    # Misc
    # -----------------------------------------------------------------
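    # Sleep mode (weight offloading) is not implemented for this backend yet;
    # sleep()/wake_up() are deliberate no-ops.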
    def sleep(self, level: int = 1) -> None:
        pass

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        pass

    def get_supported_tasks(self):
        return self.model_runner.get_supported_tasks()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def profile(self, is_start: bool = True) -> None:
        pass

    def take_draft_token_ids(self):
        return self.model_runner.take_draft_token_ids()

    def check_health(self) -> None:
        pass
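
For context, this worker only takes effect once the platform selects it. Below is a minimal sketch of how the NPUPlatform mentioned in the commit message might hand vLLM this worker class, following the vllm-ascend pattern of overriding worker_cls in check_and_update_config(); the module path and attribute values are illustrative assumptions, not taken from this commit:

from vllm.platforms.interface import Platform, PlatformEnum


class NPUPlatform(Platform):
    # Out-of-tree platform advertised through the vllm.platform_plugins entry point.
    _enum = PlatformEnum.OOT
    device_name = "npu"
    device_type = "npu"

    @classmethod
    def check_and_update_config(cls, vllm_config) -> None:
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            # Route the v1 engine to the NPU worker defined in worker_v1.py.
            parallel_config.worker_cls = "vllm_npu.worker.worker_v1.NPUWorker"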