"""NPU-optimized rotary embedding for Ascend.

Provides ``AscendRotaryEmbedding`` — a proper ``RotaryEmbedding`` subclass
with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route to the
NPU fused kernel automatically.
"""

from typing import Optional, Tuple

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


class AscendRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding using the Ascend NPU fused kernel.

    Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application.
    """

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embedding to ``query`` and ``key`` on NPU.

        Args:
            positions: Position ids, passed straight through to the NPU
                kernel.
            query: Query tensor. Its first dimension is treated as the token
                dimension; the remaining dimensions are flattened (and, on
                the partial-rotary path, regrouped into heads of
                ``self.head_size``).
            key: Key tensor, handled the same way as ``query``.
            offsets: Batched-RoPE offsets. Not supported on NPU.

        Returns:
            ``(query, key)`` with RoPE applied, reshaped back to the original
            input shapes.

        Raises:
            NotImplementedError: If ``offsets`` is not ``None``.
        """
        # Deferred import: torch_npu is only importable on Ascend hosts.
        import torch_npu

        # Reject unsupported calls before touching any cached state.
        if offsets is not None:
            raise NotImplementedError(
                "Batched rotary embedding is currently not supported on NPU."
            )

        query_shape, key_shape = query.shape, key.shape

        # Keep the cos/sin cache on the same device and dtype as the
        # activations. A single .to() call avoids an intermediate copy when
        # both device and dtype differ.
        # NOTE(review): the cast overwrites the cache in place, so switching
        # from a low-precision query dtype to a higher-precision one later
        # would leave the cache permanently rounded — confirm callers use a
        # single dtype per instance.
        cache = self.cos_sin_cache
        if cache.device != query.device or cache.dtype != query.dtype:
            self.cos_sin_cache = cache.to(device=query.device, dtype=query.dtype)

        if self.rotary_dim < self.head_size:
            # Partial rotary embedding: only the first rotary_dim elements of
            # each head are rotated; the remainder passes through unchanged.
            num_tokens = query.shape[0]
            query = query.view(num_tokens, -1, self.head_size)
            key = key.view(num_tokens, -1, self.head_size)

            q_rot = query[..., :self.rotary_dim]
            q_pass = query[..., self.rotary_dim:]
            k_rot = key[..., :self.rotary_dim]
            k_pass = key[..., self.rotary_dim:]

            # The kernel works on flat (num_tokens, -1) contiguous buffers and
            # rotates them in place.
            q_rot = q_rot.contiguous().view(num_tokens, -1)
            k_rot = k_rot.contiguous().view(num_tokens, -1)
            # NOTE(review): q_rot/k_rot are packed with rotary_dim elements per
            # head, yet head_size is passed as the per-head size — confirm the
            # kernel's expected argument here.
            torch_npu._npu_rotary_embedding(
                positions,
                q_rot,
                k_rot,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )

            # Reattach the rotated slice to the pass-through slice and restore
            # the caller's shapes.
            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
            return q, k

        # Full rotary embedding: rotate the whole head in place.
        # TODO: Remove the contiguous in the future.
        query = query.contiguous().view(query.shape[0], -1)
        key = key.contiguous().view(key.shape[0], -1)
        torch_npu._npu_rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query.view(query_shape), key.view(key_shape)