commit e75504df72eda5c122a07ee8a910671038231b70 Author: handsomezhuzhu <2658601135@qq.com> Date: Tue Feb 10 11:06:01 2026 +0800 feat: initial vllm-npu-plugin for Ascend NPU adaptation - NPUPlatform: device management, HCCL process group, config adaptation - AscendAttentionBackend: npu_fusion_attention (prefill) + npu_incre_flash_attention (decode) - NPUCommunicator: HCCL-based distributed communication - NPUWorker: NPU device init, memory profiling - Custom ops: SiluAndMul, RMS norm, rotary embedding - Plugin registered via vllm.platform_plugins entry point Based on vllm-ascend official pattern, targeting Ascend 910B diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9d93d01 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +dist/ +build/ +*.egg + +# Virtual env +.venv/ +venv/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..a6b6a46 --- /dev/null +++ b/README.md @@ -0,0 +1,95 @@ +# vllm-npu-plugin + +Ascend NPU platform plugin for vLLM v0.11.0. + +## Overview + +This package registers as an out-of-tree vLLM platform plugin via the `vllm.platform_plugins` entry-point group, enabling vLLM to run on Huawei Ascend NPU devices. + +### Components + +| Module | Description | +|---|---| +| `vllm_npu/platform.py` | `NPUPlatform` — device management, attention backend routing, config adaptation | +| `vllm_npu/distributed/communicator.py` | `NPUCommunicator` — HCCL-based distributed communication | +| `vllm_npu/attention/attention_v1.py` | `AscendAttentionBackend` — FlashAttention NPU kernels (prefill + decode) | +| `vllm_npu/worker/worker_v1.py` | `NPUWorker` — NPU device initialization and memory profiling | +| `vllm_npu/ops/` | NPU-optimized ops (SiLU+Mul, RMS norm, rotary embedding) | + +## Prerequisites + +- **Hardware**: Huawei Ascend 910B/910C or compatible NPU +- **Software**: + - CANN (Compute Architecture for Neural Networks) 8.0+ + - `torch_npu` matching your PyTorch version + - vLLM v0.11.0 (installed from source) + +## Installation + +```bash +# 1. Ensure vLLM v0.11.0 is installed with the feat/ascend-npu-adapt-v0.11.0 branch +cd /path/to/vllm +pip install -e . + +# 2. Install this plugin +cd /path/to/vllm_npu_plugin +pip install -e . +``` + +## Verification + +```bash +# Verify plugin is discoverable +python -c " +from vllm.plugins import load_plugins_by_group +plugins = load_plugins_by_group('vllm.platform_plugins') +print('Discovered plugins:', list(plugins.keys())) +assert 'npu' in plugins, 'NPU plugin not found!' 
+print('NPU plugin registered successfully!') +" + +# Verify platform detection (requires NPU hardware) +python -c " +from vllm.platforms import current_platform +print(f'Current platform: {current_platform}') +print(f'Device type: {current_platform.device_type}') +" +``` + +## Usage + +Once installed, vLLM will automatically detect the NPU platform if Ascend hardware is available: + +```bash +# Run inference on NPU +python -m vllm.entrypoints.openai.api_server \ + --model /path/to/model \ + --tensor-parallel-size 1 \ + --block-size 128 +``` + +## Architecture + +``` +vllm (v0.11.0) vllm-npu-plugin +┌─────────────────┐ ┌─────────────────────┐ +│ Platform Plugin │──entry_point──│ register() │ +│ Discovery │ │ → NPUPlatform │ +├─────────────────┤ ├─────────────────────┤ +│ AttentionBackend │◄──routing─────│ AscendAttentionBackend │ +│ Interface │ │ ├─ npu_fusion_attention│ +│ │ │ └─ npu_incre_flash_attn│ +├─────────────────┤ ├─────────────────────┤ +│ Worker Interface │◄──worker_cls──│ NPUWorker │ +│ │ │ ├─ HCCL distributed │ +│ │ │ └─ NPU memory mgmt │ +└─────────────────┘ └─────────────────────┘ +``` + +## Key API References + +- **`torch_npu.npu_fusion_attention`** — Fused multi-head attention (prefill) +- **`torch_npu.npu_incre_flash_attention`** — Incremental flash attention (decode) +- **`torch_npu._npu_reshape_and_cache`** — KV cache update +- **`torch_npu.npu_rms_norm`** / `npu_add_rms_norm` — Layer normalization +- **`torch_npu.npu_swiglu`** — Fused SiLU + Mul activation diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..be5731c --- /dev/null +++ b/setup.py @@ -0,0 +1,24 @@ +""" +vllm-npu-plugin: Ascend NPU platform plugin for vLLM v0.11.0. + +Registers as an out-of-tree platform plugin via the +``vllm.platform_plugins`` entry-point group. +""" + +from setuptools import find_packages, setup + +setup( + name="vllm-npu-plugin", + version="0.1.0", + description="Ascend NPU platform plugin for vLLM v0.11.0", + packages=find_packages(), + python_requires=">=3.9", + install_requires=[ + # vllm must already be installed (v0.11.0) + ], + entry_points={ + "vllm.platform_plugins": [ + "npu = vllm_npu:register", + ], + }, +) diff --git a/vllm_npu/__init__.py b/vllm_npu/__init__.py new file mode 100644 index 0000000..89af8be --- /dev/null +++ b/vllm_npu/__init__.py @@ -0,0 +1,12 @@ +""" +vllm_npu — Ascend NPU platform plugin for vLLM. + +The ``register()`` function is discovered by vLLM through the +``vllm.platform_plugins`` entry-point and returns the fully-qualified +class name of the platform implementation. +""" + + +def register(): + """Return the fully-qualified name of the NPU platform class.""" + return "vllm_npu.platform.NPUPlatform" diff --git a/vllm_npu/attention/__init__.py b/vllm_npu/attention/__init__.py new file mode 100644 index 0000000..d342775 --- /dev/null +++ b/vllm_npu/attention/__init__.py @@ -0,0 +1 @@ +"""Ascend NPU attention backends.""" diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py new file mode 100644 index 0000000..859df09 --- /dev/null +++ b/vllm_npu/attention/attention_v1.py @@ -0,0 +1,575 @@ +""" +Ascend NPU attention backend for vLLM v1. 
+ +Implements the ``AttentionBackend``, ``AttentionMetadata``, +``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using +Huawei Ascend NPU FlashAttention operators: + +- ``torch_npu.npu_fusion_attention`` — fused multi-head attention +- ``torch_npu._npu_reshape_and_cache`` — KV cache update +- ``torch_npu.npu_incre_flash_attention`` — paged-attention decode +""" + +from dataclasses import dataclass +from enum import IntEnum +from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, +) +from vllm.logger import init_logger +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + +logger = init_logger(__name__) + + +# ===================================================================== +# Attention state enum +# ===================================================================== + + +class AscendAttentionState(IntEnum): + """Attention computation state, determines the kernel path.""" + PrefillNoCache = 0 + PrefillCacheHit = 1 + DecodeOnly = 2 + ChunkedPrefill = 3 + + +# ===================================================================== +# Backend class +# ===================================================================== + + +class AscendAttentionBackend(AttentionBackend): + """Ascend NPU FlashAttention backend.""" + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND_ATTN" + + @staticmethod + def get_impl_cls() -> type["AttentionImpl"]: + return AscendAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> type["AscendMetadata"]: + return AscendMetadata + + @staticmethod + def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: + return AscendAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, int, int, int]: + """KV cache shape: (num_blocks, block_size, num_kv_heads, head_size). + + Key and value caches are allocated as two separate tensors with + this shape; they are paired in a ``(key_cache, value_cache)`` tuple. 
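+
+        For example (illustrative sizes, not defaults): ``num_blocks=1024``,
+        ``block_size=128``, ``num_kv_heads=8`` and ``head_size=128`` yield a
+        per-tensor shape of ``(1024, 128, 8, 128)``, i.e. roughly 256 MiB
+        per cache tensor in fp16.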
+ """ + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: List[torch.Tensor], + dst_kv_cache: List[torch.Tensor], + src_to_dst: torch.Tensor, + ) -> None: + """Swap KV cache blocks between src and dst.""" + src_key_cache, src_value_cache = src_kv_cache + dst_key_cache, dst_value_cache = dst_kv_cache + + for src_idx, dst_idx in src_to_dst.tolist(): + dst_key_cache[dst_idx].copy_(src_key_cache[src_idx]) + dst_value_cache[dst_idx].copy_(src_value_cache[src_idx]) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dsts: torch.Tensor, + ) -> None: + """Copy KV cache blocks in-place.""" + key_caches = [kv[0] for kv in kv_caches] + value_caches = [kv[1] for kv in kv_caches] + + for src_idx, dst_idx in src_to_dsts.tolist(): + for key_cache in key_caches: + key_cache[dst_idx].copy_(key_cache[src_idx]) + for value_cache in value_caches: + value_cache[dst_idx].copy_(value_cache[src_idx]) + + +# ===================================================================== +# Metadata dataclass +# ===================================================================== + + +@dataclass +class AscendMetadata: + """Per-layer attention metadata for the Ascend backend.""" + + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + num_actual_tokens: int = 0 + + # Sequence lengths and query positions + seq_lens: Optional[torch.Tensor] = None # (batch,) + seq_lens_list: Optional[List[int]] = None + query_start_loc: Optional[torch.Tensor] = None # (batch+1,) + query_lens: Optional[torch.Tensor] = None + max_query_len: Optional[int] = None + + # KV cache mapping + block_tables: Optional[torch.Tensor] = None # (batch, max_blocks) + slot_mapping: Optional[torch.Tensor] = None # (num_tokens,) + + # Attention mask (for prefill causal masking) + attn_mask: Optional[torch.Tensor] = None + + +# ===================================================================== +# Metadata builder +# ===================================================================== + + +class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): + """Builds ``AscendMetadata`` from ``CommonAttentionMetadata``.""" + + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: "VllmConfig", + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size + self.num_kv_heads = kv_cache_spec.num_kv_heads + self.head_size = kv_cache_spec.head_size + + def reorder_batch( + self, + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + ) -> bool: + """ + Reorder so decodes (query_len == 1) come first, prefills after. 
+ """ + from vllm.v1.attention.backends.utils import ( + reorder_batch_to_split_decodes_and_prefills, + ) + return reorder_batch_to_split_decodes_and_prefills( + input_batch, scheduler_output, decode_threshold=1 + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AscendMetadata: + """Build AscendMetadata from the common attention metadata.""" + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + # Determine attention state + if max_query_len == 1: + attn_state = AscendAttentionState.DecodeOnly + else: + attn_state = AscendAttentionState.ChunkedPrefill + + # Build cumulative sequence lengths for query (for prefill) + query_start_loc = common_attn_metadata.query_start_loc.to( + dtype=torch.int64 + ) + + seq_lens = common_attn_metadata.seq_lens + seq_lens_list = common_attn_metadata.seq_lens_cpu.tolist() + + # Build attention mask for prefill (causal mask) + attn_mask = None + if attn_state != AscendAttentionState.DecodeOnly: + max_seq = common_attn_metadata.max_seq_len + attn_mask = torch.ones( + max_seq, + max_seq, + dtype=torch.bool, + device=self.device, + ).triu_(diagonal=1) + + return AscendMetadata( + attn_state=attn_state, + num_actual_tokens=num_actual_tokens, + seq_lens=seq_lens, + seq_lens_list=seq_lens_list, + query_start_loc=query_start_loc, + max_query_len=max_query_len, + block_tables=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + attn_mask=attn_mask, + ) + + def build_for_cudagraph_capture( + self, + common_attn_metadata: CommonAttentionMetadata, + ) -> AscendMetadata: + """Build metadata for graph capture (decode-only).""" + return self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + +# ===================================================================== +# Attention implementation +# ===================================================================== + + +class AscendAttentionBackendImpl(AttentionImpl): + """ + Ascend NPU attention kernel implementation. + + Uses ``torch_npu.npu_fusion_attention`` for prefill and + ``torch_npu.npu_incre_flash_attention`` for decode. 
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + **kwargs, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.hidden_size = self.num_heads * self.head_size + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = sliding_window + self.attn_type = attn_type + + if alibi_slopes is not None: + alibi_slopes = torch.tensor( + alibi_slopes, dtype=torch.float32, device="npu" + ) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + # Cached references to the KV cache tensors + self._key_cache: Optional[torch.Tensor] = None + self._value_cache: Optional[torch.Tensor] = None + + def forward( + self, + layer: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, ...], + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + + Args: + query: (num_tokens, num_heads * head_size) + key: (num_tokens, num_kv_heads * head_size) + value: (num_tokens, num_kv_heads * head_size) + kv_cache: (key_cache, value_cache) each + (num_blocks, block_size, num_kv_heads, head_size) + attn_metadata: AscendMetadata for this forward call. + + Returns: + (num_tokens, num_heads * head_size) + """ + import torch_npu # noqa: F401 + + num_tokens = query.shape[0] + + if output is None: + output = torch.empty( + num_tokens, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device, + ) + + if attn_metadata is None: + return output.view(num_tokens, self.hidden_size).fill_(0) + + num_actual_tokens = attn_metadata.num_actual_tokens + + # Reshape Q/K/V to BSH (tokens, heads, head_dim) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size).contiguous() + + # ---------------------------------------------------------- + # Step 1: Update KV cache + # ---------------------------------------------------------- + if len(kv_cache) > 1: + if self._key_cache is None: + self._key_cache, self._value_cache = kv_cache[0], kv_cache[1] + + slots = attn_metadata.slot_mapping + torch_npu._npu_reshape_and_cache( + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + key_cache=self._key_cache, + value_cache=self._value_cache, + slot_indices=slots, + ) + + # ---------------------------------------------------------- + # Step 2: Compute attention + # ---------------------------------------------------------- + if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + output = self._forward_decode( + query, attn_metadata, output, num_tokens + ) + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + output = self._forward_prefill_no_cache( + query, key, value, attn_metadata, output, num_tokens + ) + else: + # ChunkedPrefill or PrefillCacheHit + output = self._forward_chunked_prefill( + query, key, value, attn_metadata, output, num_tokens + ) + + return output.view(num_tokens, self.hidden_size) + + # 
----------------------------------------------------------------- + # Decode path — paged attention via npu_incre_flash_attention + # ----------------------------------------------------------------- + + def _forward_decode( + self, + query: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + num_tokens: int, + ) -> torch.Tensor: + """Decode-only attention using incremental flash attention.""" + import torch_npu # noqa: F401 + + # npu_incre_flash_attention expects: + # query: (batch, 1, num_heads, head_size) + # key_cache: (num_blocks, block_size, num_kv_heads, head_size) + # value_cache: (num_blocks, block_size, num_kv_heads, head_size) + q = query[:num_tokens].unsqueeze(1) # (B, 1, H, D) + + attn_out = torch_npu.npu_incre_flash_attention( + q, + self._key_cache, + self._value_cache, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + actual_seq_lengths=attn_metadata.seq_lens_list, + block_size=self._key_cache.shape[1], + input_layout="BNSD", + ) + + output[:num_tokens] = attn_out.squeeze(1) + return output + + # ----------------------------------------------------------------- + # Prefill without KV cache (first token, no paging) + # ----------------------------------------------------------------- + + def _forward_prefill_no_cache( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + num_tokens: int, + ) -> torch.Tensor: + """Prefill attention without KV cache (self-attention).""" + import torch_npu # noqa: F401 + + cum_seq_len = attn_metadata.query_start_loc[1:].tolist() + + attn_out = torch_npu.npu_fusion_attention( + query[:num_tokens], + key[:num_tokens], + value[:num_tokens], + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=0, + atten_mask=attn_metadata.attn_mask, + pre_tockens=2147483647, + next_tockens=0, + actual_seq_qlen=cum_seq_len, + actual_seq_kvlen=cum_seq_len, + ) + + output[:num_tokens] = attn_out[0] + return output + + # ----------------------------------------------------------------- + # Chunked prefill — mixed prefill+decode + # ----------------------------------------------------------------- + + def _forward_chunked_prefill( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + num_tokens: int, + ) -> torch.Tensor: + """Chunked prefill using npu_fusion_attention with paged KV cache.""" + import torch_npu # noqa: F401 + + # Split batch into decodes and prefills based on query length + query_start_loc = attn_metadata.query_start_loc + seq_lens = attn_metadata.seq_lens + + # Compute per-request query lengths + query_lens = query_start_loc[1:] - query_start_loc[:-1] + num_requests = len(query_lens) + + # Separate decode (query_len == 1) and prefill requests + decode_mask = query_lens == 1 + prefill_mask = ~decode_mask + num_decodes = decode_mask.sum().item() + + # Process decode tokens + if num_decodes > 0 and self._key_cache is not None: + decode_indices = torch.where(decode_mask)[0] + decode_query = query[query_start_loc[decode_indices]] + decode_block_tables = attn_metadata.block_tables[decode_indices] + decode_seq_lens = seq_lens[decode_indices].tolist() + + decode_q = decode_query.unsqueeze(1) # (B_decode, 1, H, D) + + decode_out = torch_npu.npu_incre_flash_attention( + decode_q, + self._key_cache, + self._value_cache, + num_heads=self.num_heads, + 
num_key_value_heads=self.num_kv_heads, + scale_value=self.scale, + block_table=decode_block_tables, + actual_seq_lengths=decode_seq_lens, + block_size=self._key_cache.shape[1], + input_layout="BNSD", + ) + + for i, idx in enumerate(decode_indices): + token_pos = query_start_loc[idx].item() + output[token_pos] = decode_out[i].squeeze(0) + + # Process prefill tokens + if prefill_mask.any(): + prefill_indices = torch.where(prefill_mask)[0] + for idx in prefill_indices: + start = query_start_loc[idx].item() + end = query_start_loc[idx + 1].item() + q_len = end - start + kv_len = seq_lens[idx].item() + + q = query[start:end] # (q_len, H, D) + + # Use npu_fusion_attention for this single prefill request + # Build a causal mask for this sequence + causal_mask = torch.ones( + kv_len, kv_len, dtype=torch.bool, device=query.device + ).triu_(diagonal=1) + + # For chunked prefill, key/value come from the cache + if self._key_cache is not None and kv_len > q_len: + # Gather KV from paged cache for this request + block_table = attn_metadata.block_tables[idx] + num_blocks_needed = (kv_len + self._key_cache.shape[1] - 1) \ + // self._key_cache.shape[1] + block_ids = block_table[:num_blocks_needed] + + # Gather KV from block cache + gathered_k = self._key_cache[block_ids].reshape( + -1, self.num_kv_heads, self.head_size + )[:kv_len] + gathered_v = self._value_cache[block_ids].reshape( + -1, self.num_kv_heads, self.head_size + )[:kv_len] + + # Only last q_len rows of the mask + causal_mask = causal_mask[kv_len - q_len : kv_len, :kv_len] + + attn_out = torch_npu.npu_fusion_attention( + q.unsqueeze(0), # (1, q_len, H, D) — BSH layout + gathered_k.unsqueeze(0), + gathered_v.unsqueeze(0), + head_num=self.num_heads, + input_layout="BSND", + scale=self.scale, + sparse_mode=0, + atten_mask=causal_mask.unsqueeze(0), + pre_tockens=kv_len, + next_tockens=0, + ) + output[start:end] = attn_out[0].squeeze(0) + else: + # Full self-attention (no prior cache) + k = key[start:end] + v = value[start:end] + causal_mask = causal_mask[:q_len, :q_len] + + attn_out = torch_npu.npu_fusion_attention( + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + head_num=self.num_heads, + input_layout="BSND", + scale=self.scale, + sparse_mode=0, + atten_mask=causal_mask.unsqueeze(0), + pre_tockens=q_len, + next_tockens=0, + ) + output[start:end] = attn_out[0].squeeze(0) + + return output diff --git a/vllm_npu/distributed/__init__.py b/vllm_npu/distributed/__init__.py new file mode 100644 index 0000000..61712bf --- /dev/null +++ b/vllm_npu/distributed/__init__.py @@ -0,0 +1 @@ +"""Ascend NPU distributed communication (HCCL).""" diff --git a/vllm_npu/distributed/communicator.py b/vllm_npu/distributed/communicator.py new file mode 100644 index 0000000..627f327 --- /dev/null +++ b/vllm_npu/distributed/communicator.py @@ -0,0 +1,76 @@ +""" +NPUCommunicator — HCCL-based device communicator for Ascend NPU. + +Extends ``DeviceCommunicatorBase`` with NPU-specific collective +operations using the HCCL backend. 
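+
+As an illustration of the ``all_to_all`` convention implemented below:
+with ``world_size=2``, an input of shape ``(4, 8)`` scattered on dim 0
+and gathered on dim -1 produces a ``(2, 16)`` output on each rank (each
+rank exchanges ``(2, 8)`` chunks with its peer).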
+""" + +from typing import List, Optional + +import torch +import torch.distributed as dist +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase, +) + + +class NPUCommunicator(DeviceCommunicatorBase): + """Device communicator for Ascend NPU using HCCL.""" + + def __init__( + self, + cpu_group: dist.ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[dist.ProcessGroup] = None, + unique_name: str = "", + ): + super().__init__(cpu_group, device, device_group, unique_name) + import torch_npu # noqa: F401 + self.device = torch.npu.current_device() + + def all_to_all( + self, + input_: torch.Tensor, + scatter_dim: int = 0, + gather_dim: int = -1, + scatter_sizes: Optional[List[int]] = None, + gather_sizes: Optional[List[int]] = None, + ) -> torch.Tensor: + """All-to-all communication for NPU tensors.""" + if scatter_dim < 0: + scatter_dim += input_.dim() + if gather_dim < 0: + gather_dim += input_.dim() + + if scatter_sizes is not None and gather_sizes is not None: + input_list = [ + t.contiguous() + for t in torch.split(input_, scatter_sizes, scatter_dim) + ] + output_list = [] + tensor_shape_base = input_list[self.rank].size() + for i in range(self.world_size): + tensor_shape = list(tensor_shape_base) + tensor_shape[gather_dim] = gather_sizes[i] + output_list.append( + torch.empty( + tensor_shape, + dtype=input_.dtype, + device=input_.device, + ) + ) + else: + input_list = [ + t.contiguous() + for t in torch.tensor_split( + input_, self.world_size, scatter_dim + ) + ] + output_list = [ + torch.empty_like(input_list[i]) + for i in range(self.world_size) + ] + + dist.all_to_all(output_list, input_list, group=self.device_group) + output_tensor = torch.cat(output_list, dim=gather_dim).contiguous() + return output_tensor diff --git a/vllm_npu/ops/__init__.py b/vllm_npu/ops/__init__.py new file mode 100644 index 0000000..72bb3b3 --- /dev/null +++ b/vllm_npu/ops/__init__.py @@ -0,0 +1 @@ +"""Ascend NPU custom op registrations.""" diff --git a/vllm_npu/ops/activation.py b/vllm_npu/ops/activation.py new file mode 100644 index 0000000..d8e8ecb --- /dev/null +++ b/vllm_npu/ops/activation.py @@ -0,0 +1,17 @@ +""" +NPU-optimized activation functions for Ascend. + +Provides ``AscendSiluAndMul`` that uses ``torch_npu.npu_swiglu`` for +fused SiLU+Mul on NPU devices. +""" + +import torch +from vllm.model_executor.layers.activation import SiluAndMul + + +class AscendSiluAndMul(SiluAndMul): + """SiluAndMul using torch_npu.npu_swiglu on Ascend NPU.""" + + def forward_oot(self, x: torch.Tensor) -> torch.Tensor: + import torch_npu # noqa: F401 + return torch_npu.npu_swiglu(x) diff --git a/vllm_npu/ops/layernorm.py b/vllm_npu/ops/layernorm.py new file mode 100644 index 0000000..a973517 --- /dev/null +++ b/vllm_npu/ops/layernorm.py @@ -0,0 +1,41 @@ +""" +NPU-optimized layer normalization for Ascend. + +Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and +``torch_npu.npu_add_rms_norm``. +""" + +import torch + + +def rms_norm_npu( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> None: + """RMS norm using Ascend NPU fused kernel. + + Writes the result into ``out`` in-place. + """ + import torch_npu # noqa: F401 + normed, _residual = torch_npu.npu_rms_norm(input, weight, epsilon) + out.copy_(normed) + + +def fused_add_rms_norm_npu( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> None: + """Fused add + RMS norm using Ascend NPU kernel. 
+ + Modifies ``input`` and ``residual`` in-place. + """ + import torch_npu # noqa: F401 + normed, residual_out = torch_npu.npu_add_rms_norm( + input, residual, weight, epsilon + ) + input.copy_(normed) + residual.copy_(residual_out) diff --git a/vllm_npu/ops/rotary_embedding.py b/vllm_npu/ops/rotary_embedding.py new file mode 100644 index 0000000..37f4e60 --- /dev/null +++ b/vllm_npu/ops/rotary_embedding.py @@ -0,0 +1,31 @@ +""" +NPU-optimized rotary embedding for Ascend. + +Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application. +""" + +import torch + + +def rotary_embedding_npu( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + """Apply rotary position embedding using Ascend NPU fused kernel. + + Modifies ``query`` and ``key`` in-place. + """ + import torch_npu # noqa: F401 + + if not query.is_contiguous(): + query = query.contiguous() + if not key.is_contiguous(): + key = key.contiguous() + + torch_npu._npu_rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) diff --git a/vllm_npu/platform.py b/vllm_npu/platform.py new file mode 100644 index 0000000..e5e128e --- /dev/null +++ b/vllm_npu/platform.py @@ -0,0 +1,217 @@ +""" +NPUPlatform — Ascend NPU platform implementation for vLLM. + +Implements the ``vllm.platforms.Platform`` interface so that vLLM can +transparently target Huawei Ascend NPU devices. +""" + +import gc +import os +from datetime import timedelta +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import PrefixStore +from vllm.logger import init_logger +from vllm.platforms import Platform, PlatformEnum + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + from vllm.utils import FlexibleArgumentParser +else: + ModelConfig = None + VllmConfig = None + FlexibleArgumentParser = None + +logger = init_logger(__name__) + + +class NPUPlatform(Platform): + """Out-of-tree platform for Huawei Ascend NPU.""" + + _enum = PlatformEnum.OOT + device_name: str = "npu" + device_type: str = "npu" + dispatch_key: str = "PrivateUse1" + ray_device_key: str = "NPU" + device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" + simple_compile_backend: str = "eager" # torch.compile not supported + + # ----------------------------------------------------------------- + # Device management + # ----------------------------------------------------------------- + + @classmethod + def get_device_capability(cls, device_id: int = 0): + return None + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + import torch_npu # noqa: F401 + return torch.npu.get_device_name(device_id) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + import torch_npu # noqa: F401 + _, total = torch.npu.mem_get_info(device_id) + return total + + @classmethod + def inference_mode(cls): + return torch.inference_mode() + + @classmethod + def set_device(cls, device: torch.device): + import torch_npu # noqa: F401 + torch.npu.set_device(device) + + @classmethod + def empty_cache(cls): + import torch_npu # noqa: F401 + torch.npu.empty_cache() + + @classmethod + def synchronize(cls): + import torch_npu # noqa: F401 + torch.npu.synchronize() + + @classmethod + def mem_get_info(cls) -> Tuple[int, int]: + import torch_npu # noqa: F401 + return torch.npu.mem_get_info() + + @classmethod + def is_pin_memory_available(cls): + return True + + 
@classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + + @classmethod + def get_current_memory_usage( + cls, + device: Optional[torch.types.Device] = None, + ) -> float: + import torch_npu # noqa: F401 + torch.npu.reset_peak_memory_stats(device) + return torch.npu.max_memory_allocated(device) + + @classmethod + def clear_npu_memory(cls): + import torch_npu # noqa: F401 + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + + def is_sleep_mode_available(self) -> bool: + return False + + # ----------------------------------------------------------------- + # Attention backend routing + # ----------------------------------------------------------------- + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink=False, + use_sparse=False, + ): + return "vllm_npu.attention.attention_v1.AscendAttentionBackend" + + # ----------------------------------------------------------------- + # Distributed + # ----------------------------------------------------------------- + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm_npu.distributed.communicator.NPUCommunicator" + + @classmethod + def stateless_init_device_torch_dist_pg( + cls, + backend: str, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, + ) -> ProcessGroup: + """Create an HCCL-based process group for NPU distributed.""" + from torch.distributed import is_hccl_available + from torch_npu._C._distributed_c10d import ProcessGroupHCCL + + assert is_hccl_available(), ( + "HCCL is not available. Make sure torch_npu is properly installed." + ) + + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + + backend_options = ProcessGroupHCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupHCCL( + prefix_store, group_rank, group_size, backend_options + ) + device = torch.device("npu") + backend_class._set_sequence_number_for_group() + backend_type = ProcessGroup.BackendType.CUSTOM + + pg._register_backend(device, backend_type, backend_class) + return pg + + # ----------------------------------------------------------------- + # Configuration + # ----------------------------------------------------------------- + + @classmethod + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """Adapt vLLM configuration for NPU hardware.""" + from vllm.config import CompilationLevel + + parallel_config = vllm_config.parallel_config + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + + # Set worker class + if parallel_config and parallel_config.worker_cls == "auto": + parallel_config.worker_cls = ( + "vllm_npu.worker.worker_v1.NPUWorker" + ) + + # Set default block size for NPU (aligned to 128) + if cache_config and cache_config.block_size is None: + cache_config.block_size = 128 + + # Disable torch.compile on NPU — use eager mode + compilation_config.level = CompilationLevel.NO_COMPILATION + + logger.info( + "NPUPlatform: configuration updated — " + "worker_cls=%s, block_size=%s, compilation=NO_COMPILATION", + getattr(parallel_config, "worker_cls", "N/A"), + getattr(cache_config, "block_size", "N/A"), + ) + + @classmethod + def supports_v1(cls, model_config: "ModelConfig") -> bool: + return True + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return False + + @classmethod + def 
support_static_graph_mode(cls) -> bool: + return False diff --git a/vllm_npu/worker/__init__.py b/vllm_npu/worker/__init__.py new file mode 100644 index 0000000..b2dcbf3 --- /dev/null +++ b/vllm_npu/worker/__init__.py @@ -0,0 +1 @@ +"""Ascend NPU worker implementation.""" diff --git a/vllm_npu/worker/worker_v1.py b/vllm_npu/worker/worker_v1.py new file mode 100644 index 0000000..82810e3 --- /dev/null +++ b/vllm_npu/worker/worker_v1.py @@ -0,0 +1,230 @@ +""" +NPUWorker — Ascend NPU worker for vLLM v1. + +Extends the GPU Worker to run on Ascend NPU devices, replacing CUDA +APIs with ``torch.npu`` / ``torch_npu`` equivalents for device +management, memory profiling, and distributed initialization. +""" + +import gc +import os +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.config import VllmConfig +from vllm.distributed import init_distributed_environment +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform +from vllm.utils import GiB_bytes, STR_DTYPE_TO_TORCH_DTYPE +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.worker_base import WorkerBase + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + + +class NPUWorker(WorkerBase): + """Worker running on Ascend NPU devices.""" + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + **kwargs, + ): + super().__init__( + vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + + if self.model_config.trust_remote_code: + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Determine cache dtype + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype + ] + + self.profiler = None + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + + # ----------------------------------------------------------------- + # Device initialization + # ----------------------------------------------------------------- + + def init_device(self) -> None: + """Initialize the NPU device and distributed environment.""" + import torch_npu # noqa: F401 + + os.environ.pop("HCCL_ASYNC_ERROR_HANDLING", None) + + self.device = torch.device(f"npu:{self.local_rank}") + current_platform.set_device(self.device) + current_platform.empty_cache() + + # Record initial memory + self.init_npu_memory, self.total_npu_memory = ( + current_platform.mem_get_info() + ) + + # Initialize distributed (HCCL) + init_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + "hccl", + ) + + # Set random seed + current_platform.seed_everything(self.model_config.seed) + + # NPU memory snapshot + self.requested_memory = ( + self.total_npu_memory * self.cache_config.gpu_memory_utilization + ) + + # Construct model runner + self.model_runner: GPUModelRunner = GPUModelRunner( + self.vllm_config, self.device + ) + + # ----------------------------------------------------------------- + # Memory profiling + # ----------------------------------------------------------------- + + @torch.inference_mode() + def 
determine_available_memory(self) -> int: + """Profile peak memory and return available KV cache memory.""" + import torch_npu # noqa: F401 + + GiB = lambda b: round(b / GiB_bytes, 2) + + current_platform.empty_cache() + gc.collect() + + # Execute a forward pass with dummy inputs to profile memory + self.model_runner.profile_run() + + # Check peak memory + free_npu_memory, _ = current_platform.mem_get_info() + + assert self.init_npu_memory > free_npu_memory, ( + "Error in memory profiling. " + f"Initial free memory {GiB(self.init_npu_memory)} GiB, " + f"current free memory {GiB(free_npu_memory)} GiB." + ) + + # Get peak memory from torch_npu stats + peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"] + + current_platform.empty_cache() + torch_allocated = torch_npu.npu.memory_stats()[ + "allocated_bytes.all.current" + ] + total_allocated = ( + self.total_npu_memory - torch_npu.npu.mem_get_info()[0] + ) + non_torch = total_allocated - torch_allocated + if non_torch > 0: + peak_memory += non_torch + + available_kv_cache_memory = int( + self.total_npu_memory * self.cache_config.gpu_memory_utilization + - peak_memory + ) + available_kv_cache_memory = max(available_kv_cache_memory, 0) + + logger.info( + "Available KV cache memory: %.2f GiB (total: %.2f GiB)", + GiB(available_kv_cache_memory), + GiB(self.total_npu_memory), + ) + + gc.collect() + return available_kv_cache_memory + + # ----------------------------------------------------------------- + # Model lifecycle + # ----------------------------------------------------------------- + + def load_model(self) -> None: + self.model_runner.load_model() + + def get_model(self): + return self.model_runner.get_model() + + def get_kv_cache_spec(self) -> KVCacheSpec: + return self.model_runner.get_kv_cache_spec() + + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate KV caches on NPU.""" + self.model_runner.initialize_kv_cache(kv_cache_config) + + def compile_or_warm_up_model(self) -> None: + """Warm up the model (no torch.compile on NPU).""" + self.model_runner.capture_model() + + # ----------------------------------------------------------------- + # Execution + # ----------------------------------------------------------------- + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> Optional[ModelRunnerOutput]: + output = self.model_runner.execute_model(scheduler_output) + return output if self.is_driver_worker else None + + def execute_dummy_batch(self) -> None: + self.model_runner.execute_dummy_batch() + + # ----------------------------------------------------------------- + # Misc + # ----------------------------------------------------------------- + + def sleep(self, level: int = 1) -> None: + pass + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + pass + + def get_supported_tasks(self): + return self.model_runner.get_supported_tasks() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> set: + return self.model_runner.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def profile(self, is_start: bool = True) -> None: + pass + + def take_draft_token_ids(self): + return self.model_runner.take_draft_token_ids() + + def check_health(self) -> None: + pass
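
A minimal end-to-end smoke test for the plugin (illustrative only, not part of this commit): it assumes an Ascend 910B host with CANN, `torch_npu`, vLLM v0.11.0 and this plugin installed, and uses a placeholder model path. If the entry point is discovered, `NPUPlatform.check_and_update_config` selects `NPUWorker` and the Ascend attention backend automatically.

```python
# Hypothetical smoke test; model path and sampling settings are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/model",   # local HF-format checkpoint
    tensor_parallel_size=1,
    block_size=128,           # matches the NPU default applied by NPUPlatform
    enforce_eager=True,       # torch.compile is disabled on NPU anyway
)

outputs = llm.generate(
    ["Hello from Ascend NPU!"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```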