feat: Add Ascend NPU attention backend with NPU-specific FlashAttention, LayerNorm, and Rotary Embedding implementations.

2026-02-10 21:56:45 +08:00
parent 3aebca03d9
commit 4ca9d52cf2
4 changed files with 119 additions and 55 deletions


@@ -1,31 +1,77 @@
 """
 NPU-optimized rotary embedding for Ascend.
-Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application.
+Provides ``AscendRotaryEmbedding``, a proper ``RotaryEmbedding`` subclass
+with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
+to the NPU fused kernel automatically.
 """
+from typing import Optional, Tuple
+
 import torch
 
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
-def rotary_embedding_npu(
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    head_size: int,
-    cos_sin_cache: torch.Tensor,
-    is_neox: bool,
-) -> None:
-    """Apply rotary position embedding using Ascend NPU fused kernel.
+
+class AscendRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding using the Ascend NPU fused kernel.
 
-    Modifies ``query`` and ``key`` in-place.
+    Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application.
     """
-    import torch_npu  # noqa: F401
 
-    if not query.is_contiguous():
-        query = query.contiguous()
-    if not key.is_contiguous():
-        key = key.contiguous()
-    torch_npu._npu_rotary_embedding(
-        positions, query, key, head_size, cos_sin_cache, is_neox
-    )
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        import torch_npu  # noqa: F401
+
+        query_shape, key_shape = query.shape, key.shape
+        # The cos/sin cache may have been built on another device or in a
+        # different dtype; align it with the query before calling the kernel.
+        if self.cos_sin_cache.device != query.device:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+        if self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+        if offsets is not None:
+            raise NotImplementedError(
+                "Batched rotary embedding is currently not supported on NPU."
+            )
+        if self.rotary_dim < self.head_size:
+            # Partial rotary embedding: only the first rotary_dim dims of
+            # each head are rotated; the remaining dims pass through.
+            num_tokens = query.shape[0]
+            query = query.view(num_tokens, -1, self.head_size)
+            key = key.view(num_tokens, -1, self.head_size)
+            q_rot = query[..., :self.rotary_dim]
+            q_pass = query[..., self.rotary_dim:]
+            k_rot = key[..., :self.rotary_dim]
+            k_pass = key[..., self.rotary_dim:]
+            q_rot = q_rot.contiguous().view(num_tokens, -1)
+            k_rot = k_rot.contiguous().view(num_tokens, -1)
+            torch_npu._npu_rotary_embedding(
+                positions, q_rot, k_rot,
+                self.head_size, self.cos_sin_cache, self.is_neox_style,
+            )
+            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
+            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
+            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
+            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
+            return q, k
+        else:
+            # Full rotary embedding: the kernel rotates query/key in place.
+            # TODO: Remove the contiguous() calls in the future.
+            query = query.contiguous().view(query.shape[0], -1)
+            key = key.contiguous().view(key.shape[0], -1)
+            torch_npu._npu_rotary_embedding(
+                positions, query, key,
+                self.head_size, self.cos_sin_cache, self.is_neox_style,
+            )
+            return query.view(query_shape), key.view(key_shape)
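
A minimal usage sketch for the new class, assuming the RotaryEmbedding constructor signature of the vLLM version in use and a torch_npu environment; the import path, tensor shapes, and hyperparameters below are illustrative only, not part of this commit.

import torch
import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

from npu_backend.rotary_embedding import AscendRotaryEmbedding  # hypothetical path

head_size, num_heads = 128, 8
rope = AscendRotaryEmbedding(
    head_size=head_size,
    rotary_dim=head_size,  # full rotary; use a smaller value for partial RoPE
    max_position_embeddings=4096,
    base=10000,
    is_neox_style=True,
    dtype=torch.float16,
).to("npu")

num_tokens = 16
positions = torch.arange(num_tokens, device="npu")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="npu")
key = torch.randn(num_tokens, num_heads * head_size,
                  dtype=torch.float16, device="npu")

# On an out-of-tree platform, vLLM's CustomOp dispatch routes
# RotaryEmbedding.forward() to forward_oot(); it is called directly here
# only to make the code path explicit.
q, k = rope.forward_oot(positions, query, key)
assert q.shape == query.shape and k.shape == key.shape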