"""NPU-optimized layer normalization for Ascend.

Provides ``AscendRMSNorm`` — a proper ``RMSNorm`` subclass with
``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route to NPU
kernels automatically.
"""

from typing import Optional, Tuple, Union

import torch
from vllm.model_executor.layers.layernorm import RMSNorm


class AscendRMSNorm(RMSNorm):
    """RMSNorm using Ascend NPU fused kernels.

    Uses ``torch_npu.npu_rms_norm`` for standalone normalization and
    ``torch_npu.npu_add_rms_norm`` for fused residual-add + norm.

    Layers configured to compute the variance over only part of the hidden
    dimension (``variance_size_override``) fall back to the base class's
    ``forward_native``. The fused kernels only take the input, weight and
    epsilon, so they cannot express a partial variance size.
    """

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply RMS normalization on an Ascend NPU.

        Args:
            x: Input activations to normalize.
            residual: Optional residual tensor. When given, it is added to
                ``x`` before normalization by the fused kernel.

        Returns:
            The normalized tensor when ``residual`` is ``None``. Otherwise a
            ``(normalized, new_residual)`` tuple, where ``new_residual`` is the
            summed input returned by the fused add + norm kernel.
        """
        # Mirror upstream vLLM's CUDA path: a partial-width variance is not
        # supported by the fused kernels, so use the reference implementation.
        # getattr keeps this working on vLLM versions without the attribute.
        if getattr(self, "variance_size_override", None) is not None:
            return self.forward_native(x, residual)

        # Imported inside the method so that importing this module does not
        # require torch_npu to be installed.
        import torch_npu

        if residual is not None:
            # The kernel returns three outputs; the middle one is discarded
            # (presumably the reciprocal std — confirm against torch_npu docs)
            # and the third is the residual sum passed on to the next layer.
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon
            )
            return x, residual

        # Only the normalized output is needed; the second output is dropped.
        x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
        return x