feat: initial vllm-npu-plugin for Ascend NPU adaptation

- NPUPlatform: device management, HCCL process group, config adaptation
- AscendAttentionBackend: npu_fusion_attention (prefill) + npu_incre_flash_attention (decode)
- NPUCommunicator: HCCL-based distributed communication
- NPUWorker: NPU device init, memory profiling
- Custom ops: SiluAndMul, RMS norm, rotary embedding
- Plugin registered via vllm.platform_plugins entry point

Based on vllm-ascend official pattern, targeting Ascend 910B
This commit is contained in:
2026-02-10 11:06:01 +08:00
commit e75504df72
15 changed files with 1344 additions and 0 deletions

1
vllm_npu/ops/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Ascend NPU custom op registrations."""

View File

@@ -0,0 +1,17 @@
"""
NPU-optimized activation functions for Ascend.
Provides ``AscendSiluAndMul`` that uses ``torch_npu.npu_swiglu`` for
fused SiLU+Mul on NPU devices.
"""
import torch
from vllm.model_executor.layers.activation import SiluAndMul
class AscendSiluAndMul(SiluAndMul):
"""SiluAndMul using torch_npu.npu_swiglu on Ascend NPU."""
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
import torch_npu # noqa: F401
return torch_npu.npu_swiglu(x)

41
vllm_npu/ops/layernorm.py Normal file
View File

@@ -0,0 +1,41 @@
"""
NPU-optimized layer normalization for Ascend.
Provides RMS norm operations using ``torch_npu.npu_rms_norm`` and
``torch_npu.npu_add_rms_norm``.
"""
import torch
def rms_norm_npu(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> None:
"""RMS norm using Ascend NPU fused kernel.
Writes the result into ``out`` in-place.
"""
import torch_npu # noqa: F401
normed, _residual = torch_npu.npu_rms_norm(input, weight, epsilon)
out.copy_(normed)
def fused_add_rms_norm_npu(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> None:
"""Fused add + RMS norm using Ascend NPU kernel.
Modifies ``input`` and ``residual`` in-place.
"""
import torch_npu # noqa: F401
normed, residual_out = torch_npu.npu_add_rms_norm(
input, residual, weight, epsilon
)
input.copy_(normed)
residual.copy_(residual_out)

View File

@@ -0,0 +1,31 @@
"""
NPU-optimized rotary embedding for Ascend.
Wraps ``torch_npu._npu_rotary_embedding`` for fused RoPE application.
"""
import torch
def rotary_embedding_npu(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
"""Apply rotary position embedding using Ascend NPU fused kernel.
Modifies ``query`` and ``key`` in-place.
"""
import torch_npu # noqa: F401
if not query.is_contiguous():
query = query.contiguous()
if not key.is_contiguous():
key = key.contiguous()
torch_npu._npu_rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox
)