"""NPU-optimized rotary embedding for Ascend.

Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application.
"""

import torch


def rotary_embedding_npu(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Apply rotary position embedding using the Ascend NPU fused kernel.

    Modifies ``query`` and ``key`` in-place; returns ``None``.

    Args:
        positions: Token position indices used to index ``cos_sin_cache``.
        query: Query tensor; rotated in-place.
        key: Key tensor; rotated in-place.
        head_size: Size of each attention head.
        cos_sin_cache: Precomputed cos/sin table indexed by position.
        is_neox: Selects the NeoX (rotate-half) layout when ``True``;
            otherwise the GPT-J (interleaved) layout.
            NOTE(review): layout semantics are assumed from the kernel
            name — confirm against the torch_npu operator docs.
    """
    # Imported lazily so the module can be imported on hosts without
    # the Ascend toolchain.
    import torch_npu  # noqa: F401

    # The fused kernel requires contiguous inputs and writes in-place.
    # Rebinding ``query = query.contiguous()`` would make the kernel
    # write into a temporary *copy*, silently discarding the rotation
    # for non-contiguous callers. Run the kernel on contiguous buffers
    # and copy results back into the caller's tensors when needed.
    query_c = query if query.is_contiguous() else query.contiguous()
    key_c = key if key.is_contiguous() else key.contiguous()

    torch_npu._npu_rotary_embedding(
        positions, query_c, key_c, head_size, cos_sin_cache, is_neox
    )

    # Propagate results to the original tensors if copies were made,
    # preserving the documented in-place contract.
    if query_c is not query:
        query.copy_(query_c)
    if key_c is not key:
        key.copy_(key_c)