From 4ca9d52cf2ca7568d224a6f75c43cf4241acd027 Mon Sep 17 00:00:00 2001
From: handsomezhuzhu <2658601135@qq.com>
Date: Tue, 10 Feb 2026 21:56:45 +0800
Subject: [PATCH] feat: Add Ascend NPU attention backend with NPU-specific
 FlashAttention, LayerNorm, and Rotary Embedding implementations.

---
 vllm_npu/__init__.py               | 21 ++++
 vllm_npu/attention/attention_v1.py | 18 ++++---
 vllm_npu/ops/layernorm.py          | 51 ++++++++----------
 vllm_npu/ops/rotary_embedding.py   | 84 +++++++++++++++++++++++-------
 4 files changed, 119 insertions(+), 55 deletions(-)

diff --git a/vllm_npu/__init__.py b/vllm_npu/__init__.py
index 9ea89ec..fa5886f 100644
--- a/vllm_npu/__init__.py
+++ b/vllm_npu/__init__.py
@@ -15,4 +15,25 @@ def register():
     from vllm_npu.cuda_compat import _patch_cuda_to_npu

     _patch_cuda_to_npu()
+    # Register NPU custom ops with vLLM's CustomOp dispatch so that
+    # ops like SiluAndMul, RMSNorm, RotaryEmbedding use NPU kernels
+    # instead of falling back to CUDA (which would produce garbage).
+    _register_npu_ops()
+
     return "vllm_npu.platform.NPUPlatform"
+
+
+def _register_npu_ops():
+    """Register Ascend NPU op overrides with vLLM's CustomOp system."""
+    from vllm.model_executor.custom_op import CustomOp
+
+    from vllm_npu.ops.activation import AscendSiluAndMul
+    from vllm_npu.ops.layernorm import AscendRMSNorm
+    from vllm_npu.ops.rotary_embedding import AscendRotaryEmbedding
+
+    for name, op_cls in {
+        "SiluAndMul": AscendSiluAndMul,
+        "RMSNorm": AscendRMSNorm,
+        "RotaryEmbedding": AscendRotaryEmbedding,
+    }.items():
+        CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index dc5bcce..ad0b9b8 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -380,6 +380,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+        # TODO: Remove this contiguous in the future.
+        value = value.contiguous()

         # Step 1: Update KV cache
         if key is not None and value is not None:
@@ -467,14 +469,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 k = key[start:end].unsqueeze(0)
                 v = value[start:end].unsqueeze(0)

-                # Create boolean mask (Lower triangle=True means Keep, Upper=False means Mask)
-                # npu_fusion_attention (sparse_mode=0) interprets True as Keep?
-                # Or if True=Mask, then tril masks Past (Garbage).
-                # But triu (Upper=True) produced Garbage.
-                # So we try tril (Lower=True).
+                # npu_fusion_attention: True = mask out (do NOT attend)
+                # Upper triangle = future tokens = should be masked out
                 attn_mask = torch.ones(
                     q_len, q_len, dtype=torch.bool, device=query.device
-                ).tril_(diagonal=0).unsqueeze(0).unsqueeze(0)
+                ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)

                 # Run npu_fusion_attention (BSND)
                 attn_out = torch_npu.npu_fusion_attention(
@@ -567,9 +566,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     -1, self.num_kv_heads, self.head_size
                 )[:kv_len]

+                # npu_fusion_attention: True = mask out
+                # For chunked prefill, mask future positions
                 causal_mask = torch.ones(
                     q_len, kv_len, dtype=torch.bool, device=query.device
-                ).tril_(diagonal=kv_len - q_len)  # Adjusted for offset? Or just simple?
+                ).triu_(diagonal=kv_len - q_len + 1)
                 # logic for chunked prefill mask (non-square)?
                 # If q_len < kv_len (prefill extension), mask logic is complex.
                 # Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
@@ -594,9 +595,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 k = key[start:end]
                 v = value[start:end]

+                # npu_fusion_attention: True = mask out
                 causal_mask = torch.ones(
                     q_len, q_len, dtype=torch.bool, device=query.device
-                ).tril_(diagonal=0)
+                ).triu_(diagonal=1)

                 attn_out = torch_npu.npu_fusion_attention(
                     q.unsqueeze(0),
diff --git a/vllm_npu/ops/layernorm.py b/vllm_npu/ops/layernorm.py
index a973517..96b0cd9 100644
--- a/vllm_npu/ops/layernorm.py
+++ b/vllm_npu/ops/layernorm.py
@@ -1,41 +1,36 @@
 """
 NPU-optimized layer normalization for Ascend.

-Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and
-``torch_npu.npu_add_rms_norm``.
+Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
+``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
+to NPU kernels automatically.
 """
+from typing import Optional, Tuple, Union
+
 import torch
+from vllm.model_executor.layers.layernorm import RMSNorm


-def rms_norm_npu(
-    out: torch.Tensor,
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    epsilon: float,
-) -> None:
-    """RMS norm using Ascend NPU fused kernel.
+class AscendRMSNorm(RMSNorm):
+    """RMSNorm using Ascend NPU fused kernels.

-    Writes the result into ``out`` in-place.
+    Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
+    ``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.
     """
-    import torch_npu  # noqa: F401

-    normed, _residual = torch_npu.npu_rms_norm(input, weight, epsilon)
-    out.copy_(normed)
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu  # noqa: F401

-def fused_add_rms_norm_npu(
-    input: torch.Tensor,
-    residual: torch.Tensor,
-    weight: torch.Tensor,
-    epsilon: float,
-) -> None:
-    """Fused add + RMS norm using Ascend NPU kernel.
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm(
+                x, residual, self.weight, self.variance_epsilon
+            )
+            return x, residual

-    Modifies ``input`` and ``residual`` in-place.
-    """
-    import torch_npu  # noqa: F401

-    normed, residual_out = torch_npu.npu_add_rms_norm(
-        input, residual, weight, epsilon
-    )
-    input.copy_(normed)
-    residual.copy_(residual_out)
+        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+        return x
diff --git a/vllm_npu/ops/rotary_embedding.py b/vllm_npu/ops/rotary_embedding.py
index 37f4e60..b5797f3 100644
--- a/vllm_npu/ops/rotary_embedding.py
+++ b/vllm_npu/ops/rotary_embedding.py
@@ -1,31 +1,77 @@
 """
 NPU-optimized rotary embedding for Ascend.

-Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application.
+Provides ``AscendRotaryEmbedding`` — a proper ``RotaryEmbedding`` subclass
+with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
+to the NPU fused kernel automatically.
 """
+from typing import Optional, Tuple
+
 import torch
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


-def rotary_embedding_npu(
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    head_size: int,
-    cos_sin_cache: torch.Tensor,
-    is_neox: bool,
-) -> None:
-    """Apply rotary position embedding using Ascend NPU fused kernel.
+class AscendRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding using Ascend NPU fused kernel.

-    Modifies ``query`` and ``key`` in-place.
+    Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application.
""" - import torch_npu # noqa: F401 - if not query.is_contiguous(): - query = query.contiguous() - if not key.is_contiguous(): - key = key.contiguous() + def forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + import torch_npu # noqa: F401 - torch_npu._npu_rotary_embedding( - positions, query, key, head_size, cos_sin_cache, is_neox - ) + query_shape, key_shape = query.shape, key.shape + + if self.cos_sin_cache.device != query.device: + self.cos_sin_cache = self.cos_sin_cache.to(query.device) + if self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + + if offsets is not None: + raise NotImplementedError( + "Batched rotary embedding is currently not supported on NPU." + ) + + if self.rotary_dim < self.head_size: + # Partial rotary embedding: only rotate first rotary_dim dims + num_tokens = query.shape[0] + query = query.view(num_tokens, -1, self.head_size) + key = key.view(num_tokens, -1, self.head_size) + + q_rot = query[..., :self.rotary_dim] + q_pass = query[..., self.rotary_dim:] + k_rot = key[..., :self.rotary_dim] + k_pass = key[..., self.rotary_dim:] + + q_rot = q_rot.contiguous().view(num_tokens, -1) + k_rot = k_rot.contiguous().view(num_tokens, -1) + + torch_npu._npu_rotary_embedding( + positions, q_rot, k_rot, + self.head_size, self.cos_sin_cache, self.is_neox_style, + ) + + q_rot = q_rot.view(num_tokens, -1, self.rotary_dim) + k_rot = k_rot.view(num_tokens, -1, self.rotary_dim) + q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape) + k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape) + return q, k + else: + # Full rotary embedding + # TODO: Remove the contiguous in the future. + query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(key.shape[0], -1) + + torch_npu._npu_rotary_embedding( + positions, query, key, + self.head_size, self.cos_sin_cache, self.is_neox_style, + ) + + return query.view(query_shape), key.view(key_shape)