feat: Add Ascend NPU attention backend with NPU-specific FlashAttention, LayerNorm, and Rotary Embedding implementations.

This commit is contained in:
2026-02-10 21:56:45 +08:00
parent 3aebca03d9
commit 4ca9d52cf2
4 changed files with 119 additions and 55 deletions

View File

@@ -15,4 +15,25 @@ def register():
from vllm_npu.cuda_compat import _patch_cuda_to_npu from vllm_npu.cuda_compat import _patch_cuda_to_npu
_patch_cuda_to_npu() _patch_cuda_to_npu()
# Register NPU custom ops with vLLM's CustomOp dispatch so that
# ops like SiluAndMul, RMSNorm, RotaryEmbedding use NPU kernels
# instead of falling back to CUDA (which would produce garbage).
_register_npu_ops()
return "vllm_npu.platform.NPUPlatform" return "vllm_npu.platform.NPUPlatform"
def _register_npu_ops():
"""Register Ascend NPU op overrides with vLLM's CustomOp system."""
from vllm.model_executor.custom_op import CustomOp
from vllm_npu.ops.activation import AscendSiluAndMul
from vllm_npu.ops.layernorm import AscendRMSNorm
from vllm_npu.ops.rotary_embedding import AscendRotaryEmbedding
for name, op_cls in {
"SiluAndMul": AscendSiluAndMul,
"RMSNorm": AscendRMSNorm,
"RotaryEmbedding": AscendRotaryEmbedding,
}.items():
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)

View File

@@ -380,6 +380,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
# Step 1: Update KV cache # Step 1: Update KV cache
if key is not None and value is not None: if key is not None and value is not None:
@@ -467,14 +469,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
k = key[start:end].unsqueeze(0) k = key[start:end].unsqueeze(0)
v = value[start:end].unsqueeze(0) v = value[start:end].unsqueeze(0)
# Create boolean mask (Lower triangle=True means Keep, Upper=False means Mask) # npu_fusion_attention: True = mask out (do NOT attend)
# npu_fusion_attention (sparse_mode=0) interprets True as Keep? # Upper triangle = future tokens = should be masked out
# Or if True=Mask, then tril masks Past (Garbage).
# But triu (Upper=True) produced Garbage.
# So we try tril (Lower=True).
attn_mask = torch.ones( attn_mask = torch.ones(
q_len, q_len, dtype=torch.bool, device=query.device q_len, q_len, dtype=torch.bool, device=query.device
).tril_(diagonal=0).unsqueeze(0).unsqueeze(0) ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
# Run npu_fusion_attention (BSND) # Run npu_fusion_attention (BSND)
attn_out = torch_npu.npu_fusion_attention( attn_out = torch_npu.npu_fusion_attention(
@@ -567,9 +566,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
-1, self.num_kv_heads, self.head_size -1, self.num_kv_heads, self.head_size
)[:kv_len] )[:kv_len]
# npu_fusion_attention: True = mask out
# For chunked prefill, mask future positions
causal_mask = torch.ones( causal_mask = torch.ones(
q_len, kv_len, dtype=torch.bool, device=query.device q_len, kv_len, dtype=torch.bool, device=query.device
).tril_(diagonal=kv_len - q_len) # Adjusted for offset? Or just simple? ).triu_(diagonal=kv_len - q_len + 1)
# logic for chunked prefill mask (non-square)? # logic for chunked prefill mask (non-square)?
# If q_len < kv_len (prefill extension), mask logic is complex. # If q_len < kv_len (prefill extension), mask logic is complex.
# Usually: mask[i, j] = True if j <= i + (kv_len - q_len). # Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
@@ -594,9 +595,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
k = key[start:end] k = key[start:end]
v = value[start:end] v = value[start:end]
# npu_fusion_attention: True = mask out
causal_mask = torch.ones( causal_mask = torch.ones(
q_len, q_len, dtype=torch.bool, device=query.device q_len, q_len, dtype=torch.bool, device=query.device
).tril_(diagonal=0) ).triu_(diagonal=1)
attn_out = torch_npu.npu_fusion_attention( attn_out = torch_npu.npu_fusion_attention(
q.unsqueeze(0), q.unsqueeze(0),

View File

@@ -1,41 +1,36 @@
""" """
NPU-optimized layer normalization for Ascend. NPU-optimized layer normalization for Ascend.
Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
``torch_npu.npu_add_rms_norm``. ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
to NPU kernels automatically.
""" """
from typing import Optional, Tuple, Union
import torch import torch
from vllm.model_executor.layers.layernorm import RMSNorm
def rms_norm_npu( class AscendRMSNorm(RMSNorm):
out: torch.Tensor, """RMSNorm using Ascend NPU fused kernels.
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> None:
"""RMS norm using Ascend NPU fused kernel.
Writes the result into ``out`` in-place. Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.
""" """
import torch_npu # noqa: F401
normed, _residual = torch_npu.npu_rms_norm(input, weight, epsilon)
out.copy_(normed)
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu # noqa: F401
def fused_add_rms_norm_npu( if residual is not None:
input: torch.Tensor, x, _, residual = torch_npu.npu_add_rms_norm(
residual: torch.Tensor, x, residual, self.weight, self.variance_epsilon
weight: torch.Tensor, )
epsilon: float, return x, residual
) -> None:
"""Fused add + RMS norm using Ascend NPU kernel.
Modifies ``input`` and ``residual`` in-place. x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
""" return x
import torch_npu # noqa: F401
normed, residual_out = torch_npu.npu_add_rms_norm(
input, residual, weight, epsilon
)
input.copy_(normed)
residual.copy_(residual_out)

View File

@@ -1,31 +1,77 @@
""" """
NPU-optimized rotary embedding for Ascend. NPU-optimized rotary embedding for Ascend.
Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application. Provides ``AscendRotaryEmbedding`` — a proper ``RotaryEmbedding`` subclass
with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
to the NPU fused kernel automatically.
""" """
from typing import Optional, Tuple
import torch import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_npu( class AscendRotaryEmbedding(RotaryEmbedding):
positions: torch.Tensor, """RotaryEmbedding using Ascend NPU fused kernel.
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
"""Apply rotary position embedding using Ascend NPU fused kernel.
Modifies ``query`` and ``key`` in-place. Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application.
""" """
import torch_npu # noqa: F401
if not query.is_contiguous(): def forward_oot(
query = query.contiguous() self,
if not key.is_contiguous(): positions: torch.Tensor,
key = key.contiguous() query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
import torch_npu # noqa: F401
torch_npu._npu_rotary_embedding( query_shape, key_shape = query.shape, key.shape
positions, query, key, head_size, cos_sin_cache, is_neox
) if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU."
)
if self.rotary_dim < self.head_size:
# Partial rotary embedding: only rotate first rotary_dim dims
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., :self.rotary_dim]
q_pass = query[..., self.rotary_dim:]
k_rot = key[..., :self.rotary_dim]
k_pass = key[..., self.rotary_dim:]
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
torch_npu._npu_rotary_embedding(
positions, q_rot, k_rot,
self.head_size, self.cos_sin_cache, self.is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
else:
# Full rotary embedding
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions, query, key,
self.head_size, self.cos_sin_cache, self.is_neox_style,
)
return query.view(query_shape), key.view(key_shape)