Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
feat: Add Ascend NPU attention backend with NPU-specific FlashAttention, LayerNorm, and Rotary Embedding implementations.
@@ -1,41 +1,36 @@
 """
 NPU-optimized layer normalization for Ascend.
 
-Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and
-``torch_npu.npu_add_rms_norm``.
+Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
+``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
+to NPU kernels automatically.
 """
 
+from typing import Optional, Tuple, Union
+
 import torch
 
+from vllm.model_executor.layers.layernorm import RMSNorm
+
 
-def rms_norm_npu(
-    out: torch.Tensor,
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    epsilon: float,
-) -> None:
-    """RMS norm using Ascend NPU fused kernel.
+class AscendRMSNorm(RMSNorm):
+    """RMSNorm using Ascend NPU fused kernels.
 
-    Writes the result into ``out`` in-place.
+    Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
+    ``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.
     """
-    import torch_npu  # noqa: F401
-    normed, _residual = torch_npu.npu_rms_norm(input, weight, epsilon)
-    out.copy_(normed)
 
-
-def fused_add_rms_norm_npu(
-    input: torch.Tensor,
-    residual: torch.Tensor,
-    weight: torch.Tensor,
-    epsilon: float,
-) -> None:
-    """Fused add + RMS norm using Ascend NPU kernel.
-
-    Modifies ``input`` and ``residual`` in-place.
-    """
-    import torch_npu  # noqa: F401
-    normed, residual_out = torch_npu.npu_add_rms_norm(
-        input, residual, weight, epsilon
-    )
-    input.copy_(normed)
-    residual.copy_(residual_out)
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu  # noqa: F401
+
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm(
+                x, residual, self.weight, self.variance_epsilon
+            )
+            return x, residual
+
+        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+        return x
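Outside the engine, the new class can be exercised directly. Below is a minimal sanity-check sketch, not part of this commit (assumptions: an Ascend host with torch and torch_npu installed, vLLM importable, and AscendRMSNorm importable from the module changed above, whose path is not shown in this hunk; the plain-PyTorch reference and fp16 tolerances are illustrative).

import torch
import torch_npu  # noqa: F401  # importing torch_npu registers the "npu" device


def ref_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain-PyTorch RMS norm, used only as a reference for comparison.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


hidden_size = 128
layer = AscendRMSNorm(hidden_size, eps=1e-6).to(device="npu", dtype=torch.float16)
x = torch.randn(4, hidden_size, dtype=torch.float16, device="npu")

# Standalone path: without a residual, forward_oot returns only the normed tensor.
y = layer.forward_oot(x)
torch.testing.assert_close(
    y, ref_rms_norm(x, layer.weight, layer.variance_epsilon), rtol=1e-2, atol=1e-2
)

# Fused path: with a residual, forward_oot returns (norm(x + residual), x + residual).
res = torch.randn_like(x)
y2, new_res = layer.forward_oot(x, residual=res)
torch.testing.assert_close(new_res, x + res, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(
    y2, ref_rms_norm(x + res, layer.weight, layer.variance_epsilon), rtol=1e-2, atol=1e-2
)

In engine use the class is not called this way; per the docstring above, vLLM's CustomOp dispatch is expected to route the layer's forward to forward_oot on the NPU platform.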
@@ -1,31 +1,77 @@
 """
 NPU-optimized rotary embedding for Ascend.
 
-Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application.
+Provides ``AscendRotaryEmbedding`` — a proper ``RotaryEmbedding`` subclass
+with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
+to the NPU fused kernel automatically.
 """
 
+from typing import Optional, Tuple
+
 import torch
 
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
 
-def rotary_embedding_npu(
-    positions: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    head_size: int,
-    cos_sin_cache: torch.Tensor,
-    is_neox: bool,
-) -> None:
-    """Apply rotary position embedding using Ascend NPU fused kernel.
+class AscendRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding using Ascend NPU fused kernel.
 
-    Modifies ``query`` and ``key`` in-place.
+    Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application.
     """
-    import torch_npu  # noqa: F401
 
-    if not query.is_contiguous():
-        query = query.contiguous()
-    if not key.is_contiguous():
-        key = key.contiguous()
-
-    torch_npu._npu_rotary_embedding(
-        positions, query, key, head_size, cos_sin_cache, is_neox
-    )
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        import torch_npu  # noqa: F401
+
+        query_shape, key_shape = query.shape, key.shape
+
+        if self.cos_sin_cache.device != query.device:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+        if self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+
+        if offsets is not None:
+            raise NotImplementedError(
+                "Batched rotary embedding is currently not supported on NPU."
+            )
+
+        if self.rotary_dim < self.head_size:
+            # Partial rotary embedding: only rotate first rotary_dim dims
+            num_tokens = query.shape[0]
+            query = query.view(num_tokens, -1, self.head_size)
+            key = key.view(num_tokens, -1, self.head_size)
+
+            q_rot = query[..., :self.rotary_dim]
+            q_pass = query[..., self.rotary_dim:]
+            k_rot = key[..., :self.rotary_dim]
+            k_pass = key[..., self.rotary_dim:]
+
+            q_rot = q_rot.contiguous().view(num_tokens, -1)
+            k_rot = k_rot.contiguous().view(num_tokens, -1)
+
+            torch_npu._npu_rotary_embedding(
+                positions, q_rot, k_rot,
+                self.head_size, self.cos_sin_cache, self.is_neox_style,
+            )
+
+            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
+            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
+            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
+            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
+            return q, k
+        else:
+            # Full rotary embedding
+            # TODO: Remove the contiguous in the future.
+            query = query.contiguous().view(query.shape[0], -1)
+            key = key.contiguous().view(key.shape[0], -1)
+
+            torch_npu._npu_rotary_embedding(
+                positions, query, key,
+                self.head_size, self.cos_sin_cache, self.is_neox_style,
+            )
+
+            return query.view(query_shape), key.view(key_shape)
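Likewise, a minimal equivalence check of the fused path against the base class's pure-PyTorch forward_native could look roughly like the sketch below, which is not part of this commit (assumptions: the same Ascend environment as above, AscendRotaryEmbedding importable from the module changed in this hunk, and the upstream RotaryEmbedding constructor taking head_size, rotary_dim, max_position_embeddings, base, is_neox_style, and dtype as in recent vLLM releases; shapes and tolerances are illustrative).

import torch
import torch_npu  # noqa: F401  # importing torch_npu registers the "npu" device

head_size, num_heads, num_tokens = 128, 8, 16
rope = AscendRotaryEmbedding(
    head_size=head_size,
    rotary_dim=head_size,  # full rotary; use a smaller value to exercise the partial path
    max_position_embeddings=4096,
    base=10000,
    is_neox_style=True,
    dtype=torch.float16,
).to("npu")

positions = torch.arange(num_tokens, device="npu")
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="npu")
key = torch.randn_like(query)

# forward_native is the base class's pure-PyTorch reference; forward_oot is the
# fused NPU path added in this commit. Inputs are cloned because the NPU kernel
# works in-place on contiguous views.
q_ref, k_ref = rope.forward_native(positions, query.clone(), key.clone())
q_npu, k_npu = rope.forward_oot(positions, query.clone(), key.clone())
torch.testing.assert_close(q_npu, q_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(k_npu, k_ref, rtol=1e-2, atol=1e-2)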