Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 11:42:30 +00:00)

feat: Add Ascend NPU attention backend with NPU-specific FlashAttention, LayerNorm, and Rotary Embedding implementations.

@@ -15,4 +15,25 @@ def register():
     from vllm_npu.cuda_compat import _patch_cuda_to_npu
     _patch_cuda_to_npu()
 
+    # Register NPU custom ops with vLLM's CustomOp dispatch so that
+    # ops like SiluAndMul, RMSNorm, RotaryEmbedding use NPU kernels
+    # instead of falling back to CUDA (which would produce garbage).
+    _register_npu_ops()
+
     return "vllm_npu.platform.NPUPlatform"
+
+
+def _register_npu_ops():
+    """Register Ascend NPU op overrides with vLLM's CustomOp system."""
+    from vllm.model_executor.custom_op import CustomOp
+
+    from vllm_npu.ops.activation import AscendSiluAndMul
+    from vllm_npu.ops.layernorm import AscendRMSNorm
+    from vllm_npu.ops.rotary_embedding import AscendRotaryEmbedding
+
+    for name, op_cls in {
+        "SiluAndMul": AscendSiluAndMul,
+        "RMSNorm": AscendRMSNorm,
+        "RotaryEmbedding": AscendRotaryEmbedding,
+    }.items():
+        CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
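
Note: ``AscendSiluAndMul`` is imported and registered above, but its implementation is not part of this diff. As context only, here is a minimal sketch of what such an override can look like, following the same ``forward_oot()`` pattern as the layernorm and rotary-embedding classes below; the use of ``torch_npu.npu_swiglu`` is an assumption for illustration, not necessarily what ``vllm_npu.ops.activation`` actually ships:

import torch
from vllm.model_executor.layers.activation import SiluAndMul


class AscendSiluAndMulSketch(SiluAndMul):
    """Illustrative only: SiluAndMul backed by a fused Ascend kernel."""

    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
        import torch_npu  # noqa: F401
        # Assumed fused kernel: computes silu(x[..., :d]) * x[..., d:] in one call.
        return torch_npu.npu_swiglu(x, dim=-1)

Per the comment in the hunk above, the intent of ``CustomOp.register_oot`` is that vLLM's ``CustomOp`` dispatch picks the registered subclass wherever a model builds the corresponding layer on the NPU platform.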

@@ -380,6 +380,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+        # TODO: Remove this contiguous in the future.
+        value = value.contiguous()
 
         # Step 1: Update KV cache
         if key is not None and value is not None:

@@ -467,14 +469,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             k = key[start:end].unsqueeze(0)
             v = value[start:end].unsqueeze(0)
 
-            # Create boolean mask (Lower triangle=True means Keep, Upper=False means Mask)
-            # npu_fusion_attention (sparse_mode=0) interprets True as Keep?
-            # Or if True=Mask, then tril masks Past (Garbage).
-            # But triu (Upper=True) produced Garbage.
-            # So we try tril (Lower=True).
+            # npu_fusion_attention: True = mask out (do NOT attend)
+            # Upper triangle = future tokens = should be masked out
             attn_mask = torch.ones(
                 q_len, q_len, dtype=torch.bool, device=query.device
-            ).tril_(diagonal=0).unsqueeze(0).unsqueeze(0)
+            ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
 
             # Run npu_fusion_attention (BSND)
             attn_out = torch_npu.npu_fusion_attention(
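
The mask convention this hunk settles on can be sanity-checked on CPU with plain PyTorch. This illustrative check (not part of the commit) confirms that ``triu_(diagonal=1)`` marks exactly the strictly-future positions, i.e. the complement of the lower-triangular 'keep' mask the removed comments were reasoning about:

import torch

q_len = 4
# New convention: True = mask out (do NOT attend).
mask_out = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)
# Convention the removed code assumed: True = keep.
keep = torch.ones(q_len, q_len, dtype=torch.bool).tril_(diagonal=0)

# Exact complements: masking the strict upper triangle keeps the lower
# triangle, so token i attends only to positions j <= i.
assert torch.equal(mask_out, ~keep)
print(mask_out.int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])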

@@ -567,9 +566,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 -1, self.num_kv_heads, self.head_size
             )[:kv_len]
 
+            # npu_fusion_attention: True = mask out
+            # For chunked prefill, mask future positions
             causal_mask = torch.ones(
                 q_len, kv_len, dtype=torch.bool, device=query.device
-            ).tril_(diagonal=kv_len - q_len) # Adjusted for offset? Or just simple?
+            ).triu_(diagonal=kv_len - q_len + 1)
             # logic for chunked prefill mask (non-square)?
             # If q_len < kv_len (prefill extension), mask logic is complex.
             # Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
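
The same kind of CPU check (illustrative only) covers the non-square chunked-prefill mask: with ``triu_(diagonal=kv_len - q_len + 1)`` and the 'True = mask out' convention, query row ``i`` is left attending exactly to keys ``j <= i + (kv_len - q_len)``, which is the condition the retained comment above describes (there, 'True' refers to the keep/attend side):

import torch

q_len, kv_len = 3, 5   # 3 new query tokens, 5 keys total (2 cached + 3 new)
offset = kv_len - q_len

# True = mask out: strictly-future keys relative to each query row.
mask_out = torch.ones(q_len, kv_len, dtype=torch.bool).triu_(diagonal=offset + 1)

for i in range(q_len):
    for j in range(kv_len):
        attends = not mask_out[i, j].item()
        assert attends == (j <= i + offset)

print(mask_out.int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])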

@@ -594,9 +595,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
             k = key[start:end]
             v = value[start:end]
 
+            # npu_fusion_attention: True = mask out
             causal_mask = torch.ones(
                 q_len, q_len, dtype=torch.bool, device=query.device
-            ).tril_(diagonal=0)
+            ).triu_(diagonal=1)
 
             attn_out = torch_npu.npu_fusion_attention(
                 q.unsqueeze(0),
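
For all three mask sites, a plain-PyTorch reference of the attention result under the same boolean convention (True = mask out) gives something to compare small NPU outputs against. This is a validation sketch only; it ignores GQA head expansion and the BSND layout handled by ``npu_fusion_attention``:

import math
import torch


def reference_attention(q, k, v, mask_out):
    # q: [q_len, heads, dim], k/v: [kv_len, heads, dim]
    # mask_out: [q_len, kv_len] bool, True = do NOT attend.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("qhd,khd->hqk", q, k) * scale
    scores = scores.masked_fill(mask_out, float("-inf"))
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)


q = torch.randn(4, 2, 8)
k = torch.randn(4, 2, 8)
v = torch.randn(4, 2, 8)
mask = torch.ones(4, 4, dtype=torch.bool).triu_(diagonal=1)
print(reference_attention(q, k, v, mask).shape)  # torch.Size([4, 2, 8])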

@@ -1,41 +1,36 @@
 """
 NPU-optimized layer normalization for Ascend.
 
-Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and
-``torch_npu.npu_add_rms_norm``.
+Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
+``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
+to NPU kernels automatically.
 """
 
+from typing import Optional, Tuple, Union
+
 import torch
+from vllm.model_executor.layers.layernorm import RMSNorm
 
 
-def rms_norm_npu(
-    out: torch.Tensor,
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    epsilon: float,
-) -> None:
-    """RMS norm using Ascend NPU fused kernel.
-
-    Writes the result into ``out`` in-place.
-    """
-    import torch_npu  # noqa: F401
-    normed, _residual = torch_npu.npu_rms_norm(input, weight, epsilon)
-    out.copy_(normed)
-
-
-def fused_add_rms_norm_npu(
-    input: torch.Tensor,
-    residual: torch.Tensor,
-    weight: torch.Tensor,
-    epsilon: float,
-) -> None:
-    """Fused add + RMS norm using Ascend NPU kernel.
-
-    Modifies ``input`` and ``residual`` in-place.
-    """
-    import torch_npu  # noqa: F401
-    normed, residual_out = torch_npu.npu_add_rms_norm(
-        input, residual, weight, epsilon
-    )
-    input.copy_(normed)
-    residual.copy_(residual_out)
+class AscendRMSNorm(RMSNorm):
+    """RMSNorm using Ascend NPU fused kernels.
+
+    Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
+    ``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.
+    """
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu  # noqa: F401
+
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm(
+                x, residual, self.weight, self.variance_epsilon
+            )
+            return x, residual
+
+        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+        return x
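
A plain-PyTorch reference of the semantics ``forward_oot`` relies on from the two fused kernels (standard RMSNorm math, useful for comparing against ``RMSNorm``'s native path on small tensors; the torch_npu documentation remains authoritative for the kernels themselves):

import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # RMSNorm: scale by the reciprocal root-mean-square of the last dim.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)


def add_rms_norm_ref(x, residual, weight, eps):
    # Fused residual add + norm: the new residual is x + residual and the
    # normalized output is RMSNorm(x + residual), matching the
    # (normed, ..., residual) unpacking in forward_oot above.
    residual = x + residual
    return rms_norm_ref(residual, weight, eps), residual


x, res, w = torch.randn(2, 8), torch.randn(2, 8), torch.ones(8)
out, new_res = add_rms_norm_ref(x, res, w, eps=1e-6)
print(out.shape, new_res.shape)  # torch.Size([2, 8]) torch.Size([2, 8])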

@@ -1,31 +1,77 @@
 """
 NPU-optimized rotary embedding for Ascend.
 
-Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application.
+Provides ``AscendRotaryEmbedding`` — a proper ``RotaryEmbedding`` subclass
+with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
+to the NPU fused kernel automatically.
 """
 
+from typing import Optional, Tuple
+
 import torch
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 
-def rotary_embedding_npu(
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    head_size: int,
-    cos_sin_cache: torch.Tensor,
-    is_neox: bool,
-) -> None:
-    """Apply rotary position embedding using Ascend NPU fused kernel.
-
-    Modifies ``query`` and ``key`` in-place.
-    """
-    import torch_npu  # noqa: F401
-
-    if not query.is_contiguous():
-        query = query.contiguous()
-    if not key.is_contiguous():
-        key = key.contiguous()
-
-    torch_npu._npu_rotary_embedding(
-        positions, query, key, head_size, cos_sin_cache, is_neox
-    )
+class AscendRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding using Ascend NPU fused kernel.
+
+    Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application.
+    """
+
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        import torch_npu  # noqa: F401
+
+        query_shape, key_shape = query.shape, key.shape
+
+        if self.cos_sin_cache.device != query.device:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+        if self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+
+        if offsets is not None:
+            raise NotImplementedError(
+                "Batched rotary embedding is currently not supported on NPU."
+            )
+
+        if self.rotary_dim < self.head_size:
+            # Partial rotary embedding: only rotate first rotary_dim dims
+            num_tokens = query.shape[0]
+            query = query.view(num_tokens, -1, self.head_size)
+            key = key.view(num_tokens, -1, self.head_size)
+
+            q_rot = query[..., :self.rotary_dim]
+            q_pass = query[..., self.rotary_dim:]
+            k_rot = key[..., :self.rotary_dim]
+            k_pass = key[..., self.rotary_dim:]
+
+            q_rot = q_rot.contiguous().view(num_tokens, -1)
+            k_rot = k_rot.contiguous().view(num_tokens, -1)
+
+            torch_npu._npu_rotary_embedding(
+                positions, q_rot, k_rot,
+                self.head_size, self.cos_sin_cache, self.is_neox_style,
+            )
+
+            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
+            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
+            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
+            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
+            return q, k
+        else:
+            # Full rotary embedding
+            # TODO: Remove the contiguous in the future.
+            query = query.contiguous().view(query.shape[0], -1)
+            key = key.contiguous().view(key.shape[0], -1)
+
+            torch_npu._npu_rotary_embedding(
+                positions, query, key,
+                self.head_size, self.cos_sin_cache, self.is_neox_style,
+            )
+
+            return query.view(query_shape), key.view(key_shape)
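
For the partial-rotary branch, a CPU sketch of the neox-style rotation applied only to the first ``rotary_dim`` dimensions gives a reference for unit tests. It assumes vLLM's usual ``cos_sin_cache`` layout (cos in the first half of the last dimension, sin in the second half); the fused ``_npu_rotary_embedding`` kernel is the authority on NPU:

import torch


def rope_neox_ref(x, positions, cos_sin_cache, rotary_dim):
    # x: [num_tokens, num_heads, head_size]; only the first rotary_dim
    # dims of each head are rotated, the rest pass through unchanged.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # each [tokens, rotary_dim // 2]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)          # broadcast over heads
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)                        # neox style: split in halves
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)


head_size, rotary_dim, num_tokens = 8, 4, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(32).float(), inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)      # [32, rotary_dim]
x = torch.randn(num_tokens, 2, head_size)
out = rope_neox_ref(x, torch.arange(num_tokens), cache, rotary_dim)
print(out.shape)  # torch.Size([4, 2, 8])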