Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 11:42:30 +00:00)
- NPUPlatform: device management, HCCL process group, config adaptation
- AscendAttentionBackend: npu_fusion_attention (prefill) + npu_incre_flash_attention (decode)
- NPUCommunicator: HCCL-based distributed communication
- NPUWorker: NPU device init, memory profiling
- Custom ops: SiluAndMul, RMS norm, rotary embedding
- Plugin registered via the vllm.platform_plugins entry point (see the registration sketch below)

Based on the vllm-ascend official pattern, targeting Ascend 910B.
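For context, here is a minimal sketch of how a platform plugin is typically exposed through the vllm.platform_plugins entry-point group. The package, module, and function names (vllm_npu_plugin, register, NPUPlatform's import path) are illustrative assumptions, not taken from this repository:

# setup.py -- illustrative sketch, names are assumptions
from setuptools import setup

setup(
    name="vllm-npu-plugin",
    entry_points={
        # vLLM discovers platform plugins through this entry-point group.
        "vllm.platform_plugins": ["npu = vllm_npu_plugin:register"],
    },
)

# vllm_npu_plugin/__init__.py -- hypothetical module layout
def register() -> str:
    """Return the import path of the platform class for vLLM to load."""
    return "vllm_npu_plugin.platform.NPUPlatform"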
"""
|
|
NPU-optimized rotary embedding for Ascend.
|
|
|
|
Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application.
|
|
"""
|
|
|
|
import torch
|
|
|
|
|
|
def rotary_embedding_npu(
|
|
positions: torch.Tensor,
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
head_size: int,
|
|
cos_sin_cache: torch.Tensor,
|
|
is_neox: bool,
|
|
) -> None:
|
|
"""Apply rotary position embedding using Ascend NPU fused kernel.
|
|
|
|
Modifies ``query`` and ``key`` in-place.
|
|
"""
|
|
import torch_npu # noqa: F401
|
|
|
|
if not query.is_contiguous():
|
|
query = query.contiguous()
|
|
if not key.is_contiguous():
|
|
key = key.contiguous()
|
|
|
|
torch_npu._npu_rotary_embedding(
|
|
positions, query, key, head_size, cos_sin_cache, is_neox
|
|
)
|
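A hedged usage sketch follows. It assumes vLLM's usual layout for the cos/sin cache (cos and sin concatenated along the last dimension, one row per position), flattened query/key of shape (num_tokens, num_heads * head_size), and an Ascend device with torch_npu installed; the import path and shapes are illustrative, not taken from this repository.

import torch
# from vllm_npu_plugin.ops.rotary_embedding import rotary_embedding_npu  # hypothetical path

num_tokens, num_heads, head_size, max_pos = 4, 8, 128, 2048

# Assumed cache layout: cos and sin halves concatenated on the last dimension.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to("npu")

positions = torch.arange(num_tokens, dtype=torch.long, device="npu")
query = torch.randn(num_tokens, num_heads * head_size, device="npu")
key = torch.randn(num_tokens, num_heads * head_size, device="npu")

rotary_embedding_npu(positions, query, key, head_size, cos_sin_cache, is_neox=True)
# query and key now hold the rotated values (updated in-place).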