"""
|
|
NPU-optimized layer normalization for Ascend.
|
|
|
|
Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
|
|
``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
|
|
to NPU kernels automatically.
|
|
"""
|
|
|
|
from typing import Optional, Tuple, Union

import torch

from vllm.model_executor.layers.layernorm import RMSNorm


class AscendRMSNorm(RMSNorm):
    """RMSNorm using Ascend NPU fused kernels.

    Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
    ``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.
    """
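
    # Reference semantics of the fused kernels, as a reader-facing sketch
    # (``eps`` is ``self.variance_epsilon``, ``weight`` the learned scale):
    #
    #   rms_norm(x)        = x / sqrt(mean(x**2, dim=-1) + eps) * weight
    #   add_rms_norm(x, r) = rms_norm(x + r), which additionally returns
    #                        x + r so it can serve as the next residual.
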
    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Imported lazily so this module stays importable on hosts
        # without torch_npu (e.g. CPU-only tooling or tests).
        import torch_npu  # noqa: F401

        if residual is not None:
            # Fused residual-add + RMSNorm: the kernel returns three
            # tensors; the second (a normalization statistic) is unused
            # here, and the third, the residual sum, is handed back as
            # the new residual for the next layer.
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon
            )
            return x, residual

        # Standalone RMSNorm; the kernel's second output (a
        # normalization statistic) is discarded.
        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
        return x
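

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). It calls
# forward_oot directly; in a running vLLM, the CustomOp dispatcher selects
# forward_oot automatically on out-of-tree platforms such as Ascend, and a
# plugin would typically register this class in place of the stock RMSNorm
# (e.g. via CustomOp.register_oot in recent vLLM versions; an assumption
# about the host API, not shown here). The hidden size, batch shape, and
# the "npu" device string below are illustrative; running this requires
# torch_npu and an Ascend device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    norm = AscendRMSNorm(4096, eps=1e-6).to("npu")
    x = torch.randn(8, 4096, device="npu")

    out = norm.forward_oot(x)  # standalone npu_rms_norm path
    out, res = norm.forward_oot(x, residual=torch.zeros_like(x))  # fused path
    print(out.shape, res.shape)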