Files
vllm-npu-plugin/vllm_npu/ops/layernorm.py

37 lines
1.0 KiB
Python

"""
NPU-optimized layer normalization for Ascend.
Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
to NPU kernels automatically.
"""
from typing import Optional, Tuple, Union
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
class AscendRMSNorm(RMSNorm):
    """RMSNorm using Ascend NPU fused kernels.

    Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
    ``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.

    If the base ``RMSNorm`` was configured to compute the variance over only
    part of the hidden dimension (``variance_size_override`` is set), the
    fused kernels cannot express that. In that case the call falls back to
    the reference ``forward_native`` implementation instead of silently
    normalizing over the full dimension.
    """

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply RMSNorm, optionally fused with a residual add, on the NPU.

        Args:
            x: Input activations to normalize.
            residual: Optional residual tensor. When given, it is passed to
                the fused add + norm kernel together with ``x``.

        Returns:
            The normalized tensor when ``residual`` is ``None``. Otherwise a
            ``(normalized, updated_residual)`` tuple, where
            ``updated_residual`` is the kernel's third output.
        """
        # The fused kernels normalize over the whole last dimension, so a
        # partial-variance configuration must take the reference path.
        # getattr keeps this working on vLLM versions lacking the attribute.
        if getattr(self, "variance_size_override", None) is not None:
            return self.forward_native(x, residual)

        # Deferred import: importing this module does not require torch_npu;
        # it is only needed once the NPU path actually runs.
        import torch_npu

        if residual is not None:
            # The middle output is discarded (per the torch_npu docs it is
            # the inverse RMS). The third output becomes the new residual.
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon
            )
            return x, residual

        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
        return x