""" NPUWorker — Ascend NPU worker for vLLM v1. Extends the GPU Worker to run on Ascend NPU devices, replacing CUDA APIs with ``torch.npu`` / ``torch_npu`` equivalents for device management, memory profiling, and distributed initialization. """ import gc import os from typing import TYPE_CHECKING, Any, Optional import torch from vllm.config import VllmConfig from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.platforms import current_platform from vllm.utils import GiB_bytes, STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput logger = init_logger(__name__) class NPUWorker(WorkerBase): """Worker running on Ascend NPU devices.""" def __init__( self, vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, is_driver_worker: bool = False, **kwargs, ): super().__init__( vllm_config=vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) if self.model_config.trust_remote_code: from vllm.utils import init_cached_hf_modules init_cached_hf_modules() # Determine cache dtype if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else: self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_config.cache_dtype ] self.profiler = None self._sleep_saved_buffers: dict[str, torch.Tensor] = {} # ----------------------------------------------------------------- # Device initialization # ----------------------------------------------------------------- def init_device(self) -> None: """Initialize the NPU device and distributed environment.""" import torch_npu # noqa: F401 os.environ.pop("HCCL_ASYNC_ERROR_HANDLING", None) self.device = torch.device(f"npu:{self.local_rank}") current_platform.set_device(self.device) current_platform.empty_cache() # Record initial memory self.init_npu_memory, self.total_npu_memory = ( current_platform.mem_get_info() ) # Initialize distributed (HCCL) init_distributed_environment( world_size=self.parallel_config.world_size, rank=self.rank, distributed_init_method=self.distributed_init_method, local_rank=self.local_rank, backend="hccl", ) # Initialize TP / PP parallel groups ensure_model_parallel_initialized( tensor_model_parallel_size=( self.parallel_config.tensor_parallel_size), pipeline_model_parallel_size=( self.parallel_config.pipeline_parallel_size), ) # Set random seed current_platform.seed_everything(self.model_config.seed) # NPU memory snapshot self.requested_memory = ( self.total_npu_memory * self.cache_config.gpu_memory_utilization ) # Construct model runner self.model_runner: GPUModelRunner = GPUModelRunner( self.vllm_config, self.device ) # ----------------------------------------------------------------- # Memory profiling # ----------------------------------------------------------------- @torch.inference_mode() def determine_available_memory(self) -> int: """Profile peak memory and return available KV cache memory.""" import torch_npu # noqa: F401 GiB = lambda b: round(b / GiB_bytes, 2) current_platform.empty_cache() gc.collect() # Execute a forward pass with dummy inputs to profile memory self.model_runner.profile_run() # Check peak memory 
        free_npu_memory, _ = current_platform.mem_get_info()
        assert self.init_npu_memory > free_npu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {GiB(self.init_npu_memory)} GiB, "
            f"current free memory {GiB(free_npu_memory)} GiB."
        )

        # Get peak torch-allocated memory from torch_npu allocator stats.
        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
        current_platform.empty_cache()
        torch_allocated = torch_npu.npu.memory_stats()[
            "allocated_bytes.all.current"
        ]
        # Memory used outside the torch allocator (e.g. HCCL buffers) is the
        # difference between device-level usage and torch's own accounting.
        total_allocated = (
            self.total_npu_memory - torch_npu.npu.mem_get_info()[0]
        )
        non_torch = total_allocated - torch_allocated
        if non_torch > 0:
            peak_memory += non_torch

        # The KV cache gets whatever is left of the requested memory budget.
        available_kv_cache_memory = int(
            self.total_npu_memory * self.cache_config.gpu_memory_utilization
            - peak_memory
        )
        available_kv_cache_memory = max(available_kv_cache_memory, 0)

        logger.info(
            "Available KV cache memory: %.2f GiB (total: %.2f GiB)",
            GiB(available_kv_cache_memory),
            GiB(self.total_npu_memory),
        )

        gc.collect()
        return available_kv_cache_memory

    # -----------------------------------------------------------------
    # Model lifecycle
    # -----------------------------------------------------------------

    def load_model(self) -> None:
        self.model_runner.load_model()

    def get_model(self):
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> KVCacheSpec:
        return self.model_runner.get_kv_cache_spec()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Store the number of KV cache blocks."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate KV caches on the NPU."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Warm up the model (no torch.compile on NPU)."""
        self.model_runner.capture_model()

    # -----------------------------------------------------------------
    # Execution
    # -----------------------------------------------------------------

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        output = self.model_runner.execute_model(scheduler_output)
        # Only the driver worker returns the output to the engine.
        return output if self.is_driver_worker else None

    def execute_dummy_batch(self) -> None:
        self.model_runner.execute_dummy_batch()

    # -----------------------------------------------------------------
    # Misc
    # -----------------------------------------------------------------

    def sleep(self, level: int = 1) -> None:
        pass

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        pass

    def get_supported_tasks(self):
        return self.model_runner.get_supported_tasks()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def profile(self, is_start: bool = True) -> None:
        pass

    def take_draft_token_ids(self):
        return self.model_runner.take_draft_token_ids()

    def check_health(self) -> None:
        pass
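

# -----------------------------------------------------------------
# Illustrative usage (sketch only; nothing below is exercised by this
# module). NPUWorker is normally instantiated by the vLLM engine once an
# Ascend platform plugin registers it as the worker class, so user code
# keeps using the standard vLLM entry points. The model name is a
# placeholder, not a tested configuration.
#
#   from vllm import LLM, SamplingParams
#
#   llm = LLM(model="<your-npu-supported-model>", tensor_parallel_size=2)
#   outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
# -----------------------------------------------------------------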