mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
feat: Add Ascend NPU attention backend for vLLM using FlashAttention operators.
This commit is contained in:
@@ -5,9 +5,9 @@ Implements the ``AttentionBackend``, ``AttentionMetadata``,
|
|||||||
``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using
|
``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using
|
||||||
Huawei Ascend NPU FlashAttention operators:
|
Huawei Ascend NPU FlashAttention operators:
|
||||||
|
|
||||||
- ``torch_npu.npu_fusion_attention`` — fused multi-head attention
|
- ``torch_npu._npu_flash_attention`` — prefill attention (TND layout)
|
||||||
- ``torch_npu._npu_reshape_and_cache`` — KV cache update
|
- ``torch_npu._npu_reshape_and_cache`` — KV cache update
|
||||||
- ``torch_npu.npu_incre_flash_attention`` — paged-attention decode
|
- ``torch_npu._npu_paged_attention`` — paged-attention decode
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -319,30 +319,27 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
):
|
):
|
||||||
"""Update KV cache with new key/value tensors.
|
"""Update KV cache with new key/value tensors.
|
||||||
|
|
||||||
Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
|
Uses ``torch_npu._npu_reshape_and_cache`` for efficient in-place
|
||||||
slot_mapping indices.
|
KV cache update, matching vllm-ascend reference.
|
||||||
"""
|
"""
|
||||||
|
import torch_npu # noqa: F401
|
||||||
|
|
||||||
if kv_cache.numel() > 0:
|
if kv_cache.numel() > 0:
|
||||||
if self._key_cache is None:
|
if self._key_cache is None:
|
||||||
self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
|
self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
|
||||||
|
|
||||||
slots = attn_metadata.slot_mapping
|
slots = attn_metadata.slot_mapping
|
||||||
key_to_cache = key[:attn_metadata.num_actual_tokens]
|
num_actual = attn_metadata.num_actual_tokens
|
||||||
val_to_cache = value[:attn_metadata.num_actual_tokens]
|
torch_npu._npu_reshape_and_cache(
|
||||||
|
key=key[:num_actual],
|
||||||
# Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
|
value=value[:num_actual],
|
||||||
block_size = self._key_cache.shape[1]
|
key_cache=self._key_cache,
|
||||||
block_idx = slots // block_size
|
value_cache=self._value_cache,
|
||||||
block_offset = slots % block_size
|
slot_indices=slots,
|
||||||
self._key_cache[block_idx, block_offset] = key_to_cache
|
)
|
||||||
self._value_cache[block_idx, block_offset] = val_to_cache
|
|
||||||
|
|
||||||
return key, value
|
return key, value
|
||||||
|
|
||||||
# -----------------------------------------------------------------
|
|
||||||
# Forward dispatch
|
|
||||||
# -----------------------------------------------------------------
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: nn.Module,
|
layer: nn.Module,
|
||||||
@@ -399,15 +396,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
query, key, value, attn_metadata, output, num_tokens
|
query, key, value, attn_metadata, output, num_tokens
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# ChunkedPrefill or PrefillCacheHit
|
# ChunkedPrefill — use npu_fused_infer_attention_score
|
||||||
output = self._forward_chunked_prefill(
|
output = self._forward_chunked_prefill(
|
||||||
query, key, value, attn_metadata, output, num_tokens
|
query, attn_metadata, output, num_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# -----------------------------------------------------------------
|
# -----------------------------------------------------------------
|
||||||
# Decode path — paged attention (matches Huawei _npu_paged_attention)
|
# Decode path — paged attention via _npu_paged_attention
|
||||||
# -----------------------------------------------------------------
|
# -----------------------------------------------------------------
|
||||||
|
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
@@ -417,29 +414,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Decode-only via npu_incre_flash_attention."""
|
"""Decode-only via _npu_paged_attention (matches vllm-ascend)."""
|
||||||
import torch_npu # noqa: F401
|
import torch_npu # noqa: F401
|
||||||
|
|
||||||
q = query[:num_tokens].unsqueeze(2) # (B, H, 1, D) for BNSD
|
torch_npu._npu_paged_attention(
|
||||||
|
query=query,
|
||||||
attn_out = torch_npu.npu_incre_flash_attention(
|
key_cache=self._key_cache,
|
||||||
q,
|
value_cache=self._value_cache,
|
||||||
self._key_cache,
|
num_kv_heads=self.num_kv_heads,
|
||||||
self._value_cache,
|
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
num_key_value_heads=self.num_kv_heads,
|
|
||||||
scale_value=self.scale,
|
scale_value=self.scale,
|
||||||
block_table=attn_metadata.block_tables,
|
block_table=attn_metadata.block_tables,
|
||||||
actual_seq_lengths=attn_metadata.seq_lens_list,
|
context_lens=attn_metadata.seq_lens,
|
||||||
block_size=self._key_cache.shape[1],
|
out=output,
|
||||||
input_layout="BNSD",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
output[:num_tokens] = attn_out.squeeze(2)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# -----------------------------------------------------------------
|
# -----------------------------------------------------------------
|
||||||
# Prefill without KV cache
|
# Prefill without KV cache — _npu_flash_attention (TND layout)
|
||||||
# -----------------------------------------------------------------
|
# -----------------------------------------------------------------
|
||||||
|
|
||||||
def _forward_prefill_no_cache(
|
def _forward_prefill_no_cache(
|
||||||
@@ -451,168 +443,72 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Prefill attention without KV cache (self-attention) via per-req loop."""
|
"""Prefill attention without KV cache via _npu_flash_attention.
|
||||||
|
|
||||||
|
Uses TND layout and a pre-built causal mask from metadata.
|
||||||
|
This matches vllm-ascend's _forward_prefill_no_cache.
|
||||||
|
"""
|
||||||
import torch_npu # noqa: F401
|
import torch_npu # noqa: F401
|
||||||
|
|
||||||
query_start_loc = attn_metadata.query_start_loc
|
mask = attn_metadata.attn_mask
|
||||||
seq_lens = attn_metadata.seq_lens
|
|
||||||
num_reqs = len(seq_lens)
|
|
||||||
|
|
||||||
# Iterate and process each request independently to bypass TND issues
|
torch_npu._npu_flash_attention(
|
||||||
for i in range(num_reqs):
|
query=query,
|
||||||
start = query_start_loc[i].item()
|
key=key,
|
||||||
end = query_start_loc[i + 1].item()
|
value=value,
|
||||||
q_len = end - start
|
mask=mask,
|
||||||
|
seq_len=attn_metadata.seq_lens,
|
||||||
# Extract q, k, v (BSND)
|
scale_value=self.scale,
|
||||||
q = query[start:end].unsqueeze(0)
|
num_heads=self.num_heads,
|
||||||
k = key[start:end].unsqueeze(0)
|
num_kv_heads=self.num_kv_heads,
|
||||||
v = value[start:end].unsqueeze(0)
|
out=output,
|
||||||
|
|
||||||
# npu_fusion_attention: True = mask out (do NOT attend)
|
|
||||||
# Upper triangle = future tokens = should be masked out
|
|
||||||
attn_mask = torch.ones(
|
|
||||||
q_len, q_len, dtype=torch.bool, device=query.device
|
|
||||||
).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
|
|
||||||
|
|
||||||
# Run npu_fusion_attention (BSND)
|
|
||||||
attn_out = torch_npu.npu_fusion_attention(
|
|
||||||
q, k, v,
|
|
||||||
head_num=self.num_heads,
|
|
||||||
input_layout="BSND",
|
|
||||||
scale=self.scale,
|
|
||||||
atten_mask=attn_mask,
|
|
||||||
pre_tockens=2147483647,
|
|
||||||
next_tockens=0,
|
|
||||||
)
|
)
|
||||||
|
return output[:num_tokens, :, :]
|
||||||
output[start:end] = attn_out[0]
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------
|
# -----------------------------------------------------------------
|
||||||
# Chunked prefill — mixed prefill+decode via npu_fusion_attention
|
# Chunked prefill — npu_fused_infer_attention_score (TND layout)
|
||||||
# (npu_fused_infer_attention_score requires 4D on older CANN)
|
|
||||||
# -----------------------------------------------------------------
|
# -----------------------------------------------------------------
|
||||||
|
|
||||||
def _forward_chunked_prefill(
|
def _forward_chunked_prefill(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Chunked prefill: decode tokens via npu_incre_flash_attention,
|
"""Chunked prefill / mixed prefill+decode via
|
||||||
prefill tokens via npu_fusion_attention per request."""
|
npu_fused_infer_attention_score, matching vllm-ascend's
|
||||||
|
_forward_v1_style."""
|
||||||
import torch_npu # noqa: F401
|
import torch_npu # noqa: F401
|
||||||
|
|
||||||
query_start_loc = attn_metadata.query_start_loc
|
assert self._key_cache is not None
|
||||||
seq_lens = attn_metadata.seq_lens
|
assert attn_metadata.attn_mask is not None
|
||||||
|
|
||||||
# Per-request query lengths
|
num_block, block_size, _, _ = self._key_cache.shape
|
||||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
key = self._key_cache.view(num_block, block_size, -1)
|
||||||
|
value = self._value_cache.view(num_block, block_size, -1)
|
||||||
|
|
||||||
decode_mask = query_lens == 1
|
# Trim query to actual tokens (npu_fused_infer_attention_score
|
||||||
prefill_mask = ~decode_mask
|
# requires query.shape[0] == query_start_loc[-1])
|
||||||
num_decodes = decode_mask.sum().item()
|
actual_num_tokens = attn_metadata.query_start_loc[-1]
|
||||||
|
q = query[:actual_num_tokens]
|
||||||
|
|
||||||
# --- Decode tokens ---
|
out, _ = torch_npu.npu_fused_infer_attention_score(
|
||||||
if num_decodes > 0 and self._key_cache is not None:
|
query=q,
|
||||||
decode_indices = torch.where(decode_mask)[0]
|
key=key,
|
||||||
decode_query = query[query_start_loc[decode_indices]]
|
value=value,
|
||||||
decode_block_tables = attn_metadata.block_tables[decode_indices]
|
atten_mask=attn_metadata.attn_mask,
|
||||||
decode_seq_lens = seq_lens[decode_indices].tolist()
|
block_table=attn_metadata.block_tables,
|
||||||
|
input_layout="TND",
|
||||||
decode_out = torch_npu.npu_incre_flash_attention(
|
block_size=block_size,
|
||||||
decode_query.unsqueeze(2),
|
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
||||||
self._key_cache,
|
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
|
||||||
self._value_cache,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_key_value_heads=self.num_kv_heads,
|
num_key_value_heads=self.num_kv_heads,
|
||||||
scale_value=self.scale,
|
num_heads=self.num_heads,
|
||||||
block_table=decode_block_tables,
|
|
||||||
actual_seq_lengths=decode_seq_lens,
|
|
||||||
block_size=self._key_cache.shape[1],
|
|
||||||
input_layout="BNSD",
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, idx in enumerate(decode_indices):
|
|
||||||
token_pos = query_start_loc[idx].item()
|
|
||||||
output[token_pos] = decode_out[i].squeeze(1)
|
|
||||||
|
|
||||||
# --- Prefill tokens (per-request via npu_fusion_attention) ---
|
|
||||||
if prefill_mask.any():
|
|
||||||
prefill_indices = torch.where(prefill_mask)[0]
|
|
||||||
for idx in prefill_indices:
|
|
||||||
start = query_start_loc[idx].item()
|
|
||||||
end = query_start_loc[idx + 1].item()
|
|
||||||
q_len = end - start
|
|
||||||
kv_len = seq_lens[idx].item()
|
|
||||||
q = query[start:end]
|
|
||||||
|
|
||||||
if self._key_cache is not None and kv_len > q_len:
|
|
||||||
# Gather KV from paged cache
|
|
||||||
block_table = attn_metadata.block_tables[idx]
|
|
||||||
bs = self._key_cache.shape[1]
|
|
||||||
num_blocks_needed = (kv_len + bs - 1) // bs
|
|
||||||
block_ids = block_table[:num_blocks_needed]
|
|
||||||
|
|
||||||
gathered_k = self._key_cache[block_ids].reshape(
|
|
||||||
-1, self.num_kv_heads, self.head_size
|
|
||||||
)[:kv_len]
|
|
||||||
gathered_v = self._value_cache[block_ids].reshape(
|
|
||||||
-1, self.num_kv_heads, self.head_size
|
|
||||||
)[:kv_len]
|
|
||||||
|
|
||||||
# npu_fusion_attention: True = mask out
|
|
||||||
# For chunked prefill, mask future positions
|
|
||||||
causal_mask = torch.ones(
|
|
||||||
q_len, kv_len, dtype=torch.bool, device=query.device
|
|
||||||
).triu_(diagonal=kv_len - q_len + 1)
|
|
||||||
# logic for chunked prefill mask (non-square)?
|
|
||||||
# If q_len < kv_len (prefill extension), mask logic is complex.
|
|
||||||
# Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
|
|
||||||
# tril with diagonal adjustment.
|
|
||||||
# diagonal=kv_len - q_len ensures main diagonal alignment.
|
|
||||||
|
|
||||||
attn_out = torch_npu.npu_fusion_attention(
|
|
||||||
q.unsqueeze(0),
|
|
||||||
gathered_k.unsqueeze(0),
|
|
||||||
gathered_v.unsqueeze(0),
|
|
||||||
head_num=self.num_heads,
|
|
||||||
input_layout="BSND",
|
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
sparse_mode=0,
|
sparse_mode=3,
|
||||||
atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
|
|
||||||
pre_tockens=kv_len,
|
|
||||||
next_tockens=0,
|
|
||||||
)
|
)
|
||||||
output[start:end] = attn_out[0].squeeze(0)
|
|
||||||
else:
|
|
||||||
# Full self-attention (no prior cache)
|
|
||||||
k = key[start:end]
|
|
||||||
v = value[start:end]
|
|
||||||
|
|
||||||
# npu_fusion_attention: True = mask out
|
|
||||||
causal_mask = torch.ones(
|
|
||||||
q_len, q_len, dtype=torch.bool, device=query.device
|
|
||||||
).triu_(diagonal=1)
|
|
||||||
|
|
||||||
attn_out = torch_npu.npu_fusion_attention(
|
|
||||||
q.unsqueeze(0),
|
|
||||||
k.unsqueeze(0),
|
|
||||||
v.unsqueeze(0),
|
|
||||||
head_num=self.num_heads,
|
|
||||||
input_layout="BSND",
|
|
||||||
scale=self.scale,
|
|
||||||
sparse_mode=0,
|
|
||||||
atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
|
|
||||||
pre_tockens=q_len,
|
|
||||||
next_tockens=0,
|
|
||||||
)
|
|
||||||
output[start:end] = attn_out[0].squeeze(0)
|
|
||||||
|
|
||||||
|
output[:actual_num_tokens, :, :] = out[:actual_num_tokens, :, :]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user