"""NPU-optimized layer normalization for Ascend.

Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and
``torch_npu.npu_add_rms_norm``.
"""
import torch


def rms_norm_npu(
    out: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """RMS norm using the Ascend NPU fused kernel.

    Writes the result into ``out`` in-place; returns nothing.

    Args:
        out: Destination tensor; overwritten with the normalized values.
        input: Tensor to normalize.
        weight: Per-channel scale applied after normalization.
        epsilon: Small constant added for numerical stability.
    """
    import torch_npu  # noqa: F401

    # npu_rms_norm returns (normalized_output, rstd); only the first
    # output is needed here.
    normed, _rstd = torch_npu.npu_rms_norm(input, weight, epsilon)
    out.copy_(normed)


def fused_add_rms_norm_npu(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """Fused add + RMS norm using the Ascend NPU kernel.

    Computes ``input + residual``, RMS-normalizes the sum, and updates
    both arguments in-place: ``input`` receives the normalized output
    and ``residual`` receives the raw sum.

    Args:
        input: Tensor to add and normalize; overwritten with the
            normalized output.
        residual: Residual tensor; overwritten with ``input + residual``.
        weight: Per-channel scale applied after normalization.
        epsilon: Small constant added for numerical stability.
    """
    import torch_npu  # noqa: F401

    # npu_add_rms_norm returns THREE tensors: (normalized_output, rstd,
    # input_plus_residual). The previous two-element unpack would fail
    # (or mis-bind rstd as the residual output); unpack all three and
    # keep the summed tensor for the residual write-back.
    normed, _rstd, residual_out = torch_npu.npu_add_rms_norm(
        input, residual, weight, epsilon
    )
    input.copy_(normed)
    residual.copy_(residual_out)