"""
|
|
NPU-optimized rotary embedding for Ascend.
|
|
|
|
Provides ``AscendRotaryEmbedding`` — a proper ``RotaryEmbedding`` subclass
|
|
with ``forward_oot()`` so that vLLM's ``CustomOp`` dispatch can route
|
|
to the NPU fused kernel automatically.
|
|
"""
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
|
|
|
|
|
class AscendRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding using Ascend NPU fused kernel.

    Uses ``torch_npu._npu_rotary_embedding`` for in-place RoPE application.
    """

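    # vLLM's ``CustomOp`` dispatch resolves to ``forward_oot()`` when an
    # out-of-tree platform such as Ascend NPU is active, so callers use this
    # layer exactly like the stock ``RotaryEmbedding``.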
    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        import torch_npu  # noqa: F401

        query_shape, key_shape = query.shape, key.shape

        if self.cos_sin_cache.device != query.device:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device)
        if self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)

        if offsets is not None:
            raise NotImplementedError(
                "Batched rotary embedding is currently not supported on NPU."
            )

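        # Two paths: partial RoPE (rotary_dim < head_size) rotates only the
        # leading rotary_dim channels of each head; full RoPE rotates the
        # whole head in place with the fused kernel.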
        if self.rotary_dim < self.head_size:
            # Partial rotary embedding: only rotate first rotary_dim dims
            num_tokens = query.shape[0]
            query = query.view(num_tokens, -1, self.head_size)
            key = key.view(num_tokens, -1, self.head_size)

            q_rot = query[..., :self.rotary_dim]
            q_pass = query[..., self.rotary_dim:]
            k_rot = key[..., :self.rotary_dim]
            k_pass = key[..., self.rotary_dim:]

            q_rot = q_rot.contiguous().view(num_tokens, -1)
            k_rot = k_rot.contiguous().view(num_tokens, -1)

            torch_npu._npu_rotary_embedding(
                positions, q_rot, k_rot,
                self.head_size, self.cos_sin_cache, self.is_neox_style,
            )

            q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
            k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
            q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
            k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
            return q, k
        else:
            # Full rotary embedding
            # TODO: Remove the contiguous in the future.
            query = query.contiguous().view(query.shape[0], -1)
            key = key.contiguous().view(key.shape[0], -1)

            torch_npu._npu_rotary_embedding(
                positions, query, key,
                self.head_size, self.cos_sin_cache, self.is_neox_style,
            )

            return query.view(query_shape), key.view(key_shape)
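

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the plugin itself. It assumes an
# Ascend NPU device is visible as ``"npu"`` and that ``torch_npu`` is
# installed; the sizes below are arbitrary illustration values. Setting
# ``rotary_dim`` smaller than ``head_size`` would exercise the partial-RoPE
# path instead of the full one.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch_npu  # noqa: F401  # registers the "npu" device backend

    head_size, rotary_dim, num_heads, num_tokens = 128, 128, 8, 4
    rope = AscendRotaryEmbedding(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.float16,
    ).to("npu")

    positions = torch.arange(num_tokens, dtype=torch.int64, device="npu")
    query = torch.randn(num_tokens, num_heads * head_size,
                        dtype=torch.float16, device="npu")
    key = torch.randn_like(query)

    q, k = rope.forward_oot(positions, query, key)
    # Shapes are preserved: [num_tokens, num_heads * head_size] for both.
    print(q.shape, k.shape)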