feat: Add Ascend NPU attention backend with NPU-specific FlashAttention, LayerNorm, and Rotary Embedding implementations.

2026-02-10 21:56:45 +08:00
parent 3aebca03d9
commit 4ca9d52cf2
4 changed files with 119 additions and 55 deletions


@@ -1,41 +1,36 @@
"""
NPU-optimized layer normalization for Ascend.
Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and
``torch_npu.npu_add_rms_norm``.
Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
to NPU kernels automatically.
"""
from typing import Optional, Tuple, Union
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
def rms_norm_npu(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> None:
"""RMS norm using Ascend NPU fused kernel.
class AscendRMSNorm(RMSNorm):
"""RMSNorm using Ascend NPU fused kernels.
Writes the result into ``out`` in-place.
Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.
"""
import torch_npu # noqa: F401
normed, _residual = torch_npu.npu_rms_norm(input, weight, epsilon)
out.copy_(normed)
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu # noqa: F401
def fused_add_rms_norm_npu(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> None:
"""Fused add + RMS norm using Ascend NPU kernel.
if residual is not None:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon
)
return x, residual
Modifies ``input`` and ``residual`` in-place.
"""
import torch_npu # noqa: F401
normed, residual_out = torch_npu.npu_add_rms_norm(
input, residual, weight, epsilon
)
input.copy_(normed)
residual.copy_(residual_out)
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
return x
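
A minimal usage sketch (not part of the commit) of how AscendRMSNorm is expected to be exercised: it assumes the class is importable from the new Ascend layernorm module (the import path below is hypothetical, since the file path is not shown here) and that vLLM's CustomOp dispatch routes the layer's __call__ to forward_oot() on the NPU platform. It compares the fused NPU kernels against RMSNorm.forward_native() as a quick numerical sanity check.

import torch
import torch_npu  # noqa: F401  # importing torch_npu registers the "npu" device

# Hypothetical import path; the actual module path is not shown in this diff.
from vllm_ascend.ops.layernorm import AscendRMSNorm


def check_ascend_rms_norm(hidden_size: int = 128, num_tokens: int = 4) -> None:
    # Keep weights and activations in the same dtype for the NPU kernels.
    layer = AscendRMSNorm(hidden_size, eps=1e-6).to(device="npu", dtype=torch.float16)
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="npu")
    residual = torch.randn_like(x)

    # Standalone norm: CustomOp dispatch is expected to route __call__ to
    # forward_oot(), which wraps torch_npu.npu_rms_norm.
    out = layer(x)
    ref = layer.forward_native(x)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)

    # Fused residual-add + norm: forward_oot() returns (normed, new_residual),
    # backed by torch_npu.npu_add_rms_norm.
    out, new_residual = layer(x, residual)
    ref_out, ref_residual = layer.forward_native(x, residual)
    torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(new_residual, ref_residual, rtol=1e-2, atol=1e-2)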