feat: Add Ascend NPU attention backend for vLLM using FlashAttention operators.

This commit is contained in:
2026-02-10 22:15:26 +08:00
parent 5bef2da1f1
commit e22617f72e

View File

@@ -5,9 +5,9 @@ Implements the ``AttentionBackend``, ``AttentionMetadata``,
``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using
Huawei Ascend NPU FlashAttention operators:
- ``torch_npu.npu_fusion_attention`` — fused multi-head attention
- ``torch_npu._npu_flash_attention`` prefill attention (TND layout)
- ``torch_npu._npu_reshape_and_cache`` — KV cache update
- ``torch_npu.npu_incre_flash_attention`` — paged-attention decode
- ``torch_npu._npu_paged_attention`` — paged-attention decode
"""
from dataclasses import dataclass
@@ -319,30 +319,27 @@ class AscendAttentionBackendImpl(AttentionImpl):
):
"""Update KV cache with new key/value tensors.
Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
slot_mapping indices.
Uses ``torch_npu._npu_reshape_and_cache`` for efficient in-place
KV cache update, matching vllm-ascend reference.
"""
import torch_npu # noqa: F401
if kv_cache.numel() > 0:
if self._key_cache is None:
self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
key_to_cache = key[:attn_metadata.num_actual_tokens]
val_to_cache = value[:attn_metadata.num_actual_tokens]
# Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
block_size = self._key_cache.shape[1]
block_idx = slots // block_size
block_offset = slots % block_size
self._key_cache[block_idx, block_offset] = key_to_cache
self._value_cache[block_idx, block_offset] = val_to_cache
num_actual = attn_metadata.num_actual_tokens
torch_npu._npu_reshape_and_cache(
key=key[:num_actual],
value=value[:num_actual],
key_cache=self._key_cache,
value_cache=self._value_cache,
slot_indices=slots,
)
return key, value
# -----------------------------------------------------------------
# Forward dispatch
# -----------------------------------------------------------------
def forward(
self,
layer: nn.Module,
@@ -399,15 +396,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
query, key, value, attn_metadata, output, num_tokens
)
else:
# ChunkedPrefill or PrefillCacheHit
# ChunkedPrefill — use npu_fused_infer_attention_score
output = self._forward_chunked_prefill(
query, key, value, attn_metadata, output, num_tokens
query, attn_metadata, output, num_tokens
)
return output
# -----------------------------------------------------------------
# Decode path — paged attention (matches Huawei _npu_paged_attention)
# Decode path — paged attention via _npu_paged_attention
# -----------------------------------------------------------------
def _forward_decode(
@@ -417,29 +414,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
output: torch.Tensor,
num_tokens: int,
) -> torch.Tensor:
"""Decode-only via npu_incre_flash_attention."""
"""Decode-only via _npu_paged_attention (matches vllm-ascend)."""
import torch_npu # noqa: F401
q = query[:num_tokens].unsqueeze(2) # (B, H, 1, D) for BNSD
attn_out = torch_npu.npu_incre_flash_attention(
q,
self._key_cache,
self._value_cache,
torch_npu._npu_paged_attention(
query=query,
key_cache=self._key_cache,
value_cache=self._value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
actual_seq_lengths=attn_metadata.seq_lens_list,
block_size=self._key_cache.shape[1],
input_layout="BNSD",
context_lens=attn_metadata.seq_lens,
out=output,
)
output[:num_tokens] = attn_out.squeeze(2)
return output
# -----------------------------------------------------------------
# Prefill without KV cache
# Prefill without KV cache — _npu_flash_attention (TND layout)
# -----------------------------------------------------------------
def _forward_prefill_no_cache(
@@ -451,168 +443,72 @@ class AscendAttentionBackendImpl(AttentionImpl):
output: torch.Tensor,
num_tokens: int,
) -> torch.Tensor:
"""Prefill attention without KV cache (self-attention) via per-req loop."""
"""Prefill attention without KV cache via _npu_flash_attention.
Uses TND layout and a pre-built causal mask from metadata.
This matches vllm-ascend's _forward_prefill_no_cache.
"""
import torch_npu # noqa: F401
query_start_loc = attn_metadata.query_start_loc
seq_lens = attn_metadata.seq_lens
num_reqs = len(seq_lens)
mask = attn_metadata.attn_mask
# Iterate and process each request independently to bypass TND issues
for i in range(num_reqs):
start = query_start_loc[i].item()
end = query_start_loc[i + 1].item()
q_len = end - start
# Extract q, k, v (BSND)
q = query[start:end].unsqueeze(0)
k = key[start:end].unsqueeze(0)
v = value[start:end].unsqueeze(0)
# npu_fusion_attention: True = mask out (do NOT attend)
# Upper triangle = future tokens = should be masked out
attn_mask = torch.ones(
q_len, q_len, dtype=torch.bool, device=query.device
).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
# Run npu_fusion_attention (BSND)
attn_out = torch_npu.npu_fusion_attention(
q, k, v,
head_num=self.num_heads,
input_layout="BSND",
scale=self.scale,
atten_mask=attn_mask,
pre_tockens=2147483647,
next_tockens=0,
)
output[start:end] = attn_out[0]
return output
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output,
)
return output[:num_tokens, :, :]
# -----------------------------------------------------------------
# Chunked prefill — mixed prefill+decode via npu_fusion_attention
# (npu_fused_infer_attention_score requires 4D on older CANN)
# Chunked prefill — npu_fused_infer_attention_score (TND layout)
# -----------------------------------------------------------------
def _forward_chunked_prefill(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
num_tokens: int,
) -> torch.Tensor:
"""Chunked prefill: decode tokens via npu_incre_flash_attention,
prefill tokens via npu_fusion_attention per request."""
"""Chunked prefill / mixed prefill+decode via
npu_fused_infer_attention_score, matching vllm-ascend's
_forward_v1_style."""
import torch_npu # noqa: F401
query_start_loc = attn_metadata.query_start_loc
seq_lens = attn_metadata.seq_lens
assert self._key_cache is not None
assert attn_metadata.attn_mask is not None
# Per-request query lengths
query_lens = query_start_loc[1:] - query_start_loc[:-1]
num_block, block_size, _, _ = self._key_cache.shape
key = self._key_cache.view(num_block, block_size, -1)
value = self._value_cache.view(num_block, block_size, -1)
decode_mask = query_lens == 1
prefill_mask = ~decode_mask
num_decodes = decode_mask.sum().item()
# Trim query to actual tokens (npu_fused_infer_attention_score
# requires query.shape[0] == query_start_loc[-1])
actual_num_tokens = attn_metadata.query_start_loc[-1]
q = query[:actual_num_tokens]
# --- Decode tokens ---
if num_decodes > 0 and self._key_cache is not None:
decode_indices = torch.where(decode_mask)[0]
decode_query = query[query_start_loc[decode_indices]]
decode_block_tables = attn_metadata.block_tables[decode_indices]
decode_seq_lens = seq_lens[decode_indices].tolist()
decode_out = torch_npu.npu_incre_flash_attention(
decode_query.unsqueeze(2),
self._key_cache,
self._value_cache,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
scale_value=self.scale,
block_table=decode_block_tables,
actual_seq_lengths=decode_seq_lens,
block_size=self._key_cache.shape[1],
input_layout="BNSD",
)
for i, idx in enumerate(decode_indices):
token_pos = query_start_loc[idx].item()
output[token_pos] = decode_out[i].squeeze(1)
# --- Prefill tokens (per-request via npu_fusion_attention) ---
if prefill_mask.any():
prefill_indices = torch.where(prefill_mask)[0]
for idx in prefill_indices:
start = query_start_loc[idx].item()
end = query_start_loc[idx + 1].item()
q_len = end - start
kv_len = seq_lens[idx].item()
q = query[start:end]
if self._key_cache is not None and kv_len > q_len:
# Gather KV from paged cache
block_table = attn_metadata.block_tables[idx]
bs = self._key_cache.shape[1]
num_blocks_needed = (kv_len + bs - 1) // bs
block_ids = block_table[:num_blocks_needed]
gathered_k = self._key_cache[block_ids].reshape(
-1, self.num_kv_heads, self.head_size
)[:kv_len]
gathered_v = self._value_cache[block_ids].reshape(
-1, self.num_kv_heads, self.head_size
)[:kv_len]
# npu_fusion_attention: True = mask out
# For chunked prefill, mask future positions
causal_mask = torch.ones(
q_len, kv_len, dtype=torch.bool, device=query.device
).triu_(diagonal=kv_len - q_len + 1)
# logic for chunked prefill mask (non-square)?
# If q_len < kv_len (prefill extension), mask logic is complex.
# Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
# tril with diagonal adjustment.
# diagonal=kv_len - q_len ensures main diagonal alignment.
attn_out = torch_npu.npu_fusion_attention(
q.unsqueeze(0),
gathered_k.unsqueeze(0),
gathered_v.unsqueeze(0),
head_num=self.num_heads,
input_layout="BSND",
scale=self.scale,
sparse_mode=0,
atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
pre_tockens=kv_len,
next_tockens=0,
)
output[start:end] = attn_out[0].squeeze(0)
else:
# Full self-attention (no prior cache)
k = key[start:end]
v = value[start:end]
# npu_fusion_attention: True = mask out
causal_mask = torch.ones(
q_len, q_len, dtype=torch.bool, device=query.device
).triu_(diagonal=1)
attn_out = torch_npu.npu_fusion_attention(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
head_num=self.num_heads,
input_layout="BSND",
scale=self.scale,
sparse_mode=0,
atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
pre_tockens=q_len,
next_tockens=0,
)
output[start:end] = attn_out[0].squeeze(0)
out, _ = torch_npu.npu_fused_infer_attention_score(
query=q,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
output[:actual_num_tokens, :, :] = out[:actual_num_tokens, :, :]
return output