fix: proper PrefillNoCache detection, fallback to npu_fusion_attention for chunked prefill (CANN compat)

This commit is contained in:
2026-02-10 20:14:42 +08:00
parent 810a2ef757
commit e7655a0745

View File

@@ -202,18 +202,26 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
max_query_len = common_attn_metadata.max_query_len max_query_len = common_attn_metadata.max_query_len
# Determine attention state # Determine attention state
num_reqs = common_attn_metadata.num_reqs
if max_query_len == 1: if max_query_len == 1:
attn_state = AscendAttentionState.DecodeOnly attn_state = AscendAttentionState.DecodeOnly
else: else:
attn_state = AscendAttentionState.ChunkedPrefill # Check if this is a pure prefill (no prior cache) or chunked
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs]
# PrefillNoCache: all requests have query_len == seq_len
if (query_lens_cpu == seq_lens_cpu).all():
attn_state = AscendAttentionState.PrefillNoCache
else:
attn_state = AscendAttentionState.ChunkedPrefill
# Build cumulative sequence lengths for query (for prefill) # Build cumulative sequence lengths for query (for prefill)
num_reqs = common_attn_metadata.num_reqs query_start_loc_cpu_full = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc.to( query_start_loc = common_attn_metadata.query_start_loc.to(
dtype=torch.int64 dtype=torch.int64
) )
actual_seq_lengths_q = query_start_loc_cpu[1:].tolist() actual_seq_lengths_q = query_start_loc_cpu_full[1:].tolist()
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
@@ -333,30 +341,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
return key, value return key, value
# ----------------------------------------------------------------- # -----------------------------------------------------------------
# Forward dispatch (matches Huawei vllm-ascend structure) # Forward dispatch
# ----------------------------------------------------------------- # -----------------------------------------------------------------
def _get_fia_params(
self,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: "AscendMetadata",
):
"""Prepare key, value, block_size, block_table and kv_seq_lens
for npu_fused_infer_attention_score, following Huawei's approach."""
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.query_start_loc[1:].tolist()
else:
# DecodeOnly / PrefillCacheHit / ChunkedPrefill — read from cache
num_block, block_size, _, _ = self._key_cache.shape
key = self._key_cache.view(num_block, block_size, -1)
value = self._value_cache.view(num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
return key, value, block_size, block_table, actual_seq_lengths_kv
def forward( def forward(
self, self,
layer: nn.Module, layer: nn.Module,
@@ -411,9 +398,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
query, key, value, attn_metadata, output, num_tokens query, key, value, attn_metadata, output, num_tokens
) )
else: else:
# ChunkedPrefill or PrefillCacheHit — use FIA with block tables # ChunkedPrefill or PrefillCacheHit
output = self._forward_fused_infer_attention( output = self._forward_chunked_prefill(
query, key, value, attn_metadata, output query, key, value, attn_metadata, output, num_tokens
) )
return output return output
@@ -487,45 +474,119 @@ class AscendAttentionBackendImpl(AttentionImpl):
return output return output
# ----------------------------------------------------------------- # -----------------------------------------------------------------
# Fused Infer Attention (prefill with cache / chunked prefill) # Chunked prefill — mixed prefill+decode via npu_fusion_attention
# Matches Huawei's forward_fused_infer_attention approach # (npu_fused_infer_attention_score requires 4D on older CANN)
# ----------------------------------------------------------------- # -----------------------------------------------------------------
def _forward_fused_infer_attention( def _forward_chunked_prefill(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: AscendMetadata, attn_metadata: AscendMetadata,
output: torch.Tensor, output: torch.Tensor,
num_tokens: int,
) -> torch.Tensor: ) -> torch.Tensor:
"""Use npu_fused_infer_attention_score with TND layout and block """Chunked prefill: decode tokens via npu_incre_flash_attention,
tables — the same approach Huawei uses for chunked prefill and prefill tokens via npu_fusion_attention per request."""
cache-hit prefill."""
import torch_npu # noqa: F401 import torch_npu # noqa: F401
key, value, block_size, block_table, actual_seq_lengths_kv = ( query_start_loc = attn_metadata.query_start_loc
self._get_fia_params(key, value, attn_metadata) seq_lens = attn_metadata.seq_lens
)
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
query = query[:num_tokens]
attn_output, _ = torch_npu.npu_fused_infer_attention_score( # Per-request query lengths
query=query, query_lens = query_start_loc[1:] - query_start_loc[:-1]
key=key,
value=value, decode_mask = query_lens == 1
atten_mask=attn_metadata.attn_mask, prefill_mask = ~decode_mask
block_table=block_table, num_decodes = decode_mask.sum().item()
input_layout="TND",
block_size=block_size, # --- Decode tokens ---
actual_seq_lengths=attn_metadata.actual_seq_lengths_q, if num_decodes > 0 and self._key_cache is not None:
actual_seq_lengths_kv=actual_seq_lengths_kv, decode_indices = torch.where(decode_mask)[0]
num_key_value_heads=self.num_kv_heads, decode_query = query[query_start_loc[decode_indices]]
num_heads=self.num_heads, decode_block_tables = attn_metadata.block_tables[decode_indices]
scale=self.scale, decode_seq_lens = seq_lens[decode_indices].tolist()
sparse_mode=3,
) decode_out = torch_npu.npu_incre_flash_attention(
decode_query.unsqueeze(1),
self._key_cache,
self._value_cache,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
scale_value=self.scale,
block_table=decode_block_tables,
actual_seq_lengths=decode_seq_lens,
block_size=self._key_cache.shape[1],
input_layout="BNSD",
)
for i, idx in enumerate(decode_indices):
token_pos = query_start_loc[idx].item()
output[token_pos] = decode_out[i].squeeze(0)
# --- Prefill tokens (per-request via npu_fusion_attention) ---
if prefill_mask.any():
prefill_indices = torch.where(prefill_mask)[0]
for idx in prefill_indices:
start = query_start_loc[idx].item()
end = query_start_loc[idx + 1].item()
q_len = end - start
kv_len = seq_lens[idx].item()
q = query[start:end]
if self._key_cache is not None and kv_len > q_len:
# Gather KV from paged cache
block_table = attn_metadata.block_tables[idx]
bs = self._key_cache.shape[1]
num_blocks_needed = (kv_len + bs - 1) // bs
block_ids = block_table[:num_blocks_needed]
gathered_k = self._key_cache[block_ids].reshape(
-1, self.num_kv_heads, self.head_size
)[:kv_len]
gathered_v = self._value_cache[block_ids].reshape(
-1, self.num_kv_heads, self.head_size
)[:kv_len]
causal_mask = torch.ones(
q_len, kv_len, dtype=torch.bool, device=query.device
).triu_(diagonal=kv_len - q_len + 1)
attn_out = torch_npu.npu_fusion_attention(
q.unsqueeze(0),
gathered_k.unsqueeze(0),
gathered_v.unsqueeze(0),
head_num=self.num_heads,
input_layout="BSND",
scale=self.scale,
sparse_mode=0,
atten_mask=causal_mask.unsqueeze(0),
pre_tockens=kv_len,
next_tockens=0,
)
output[start:end] = attn_out[0].squeeze(0)
else:
# Full self-attention (no prior cache)
k = key[start:end]
v = value[start:end]
causal_mask = torch.ones(
q_len, q_len, dtype=torch.bool, device=query.device
).triu_(diagonal=1)
attn_out = torch_npu.npu_fusion_attention(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
head_num=self.num_heads,
input_layout="BSND",
scale=self.scale,
sparse_mode=0,
atten_mask=causal_mask.unsqueeze(0),
pre_tockens=q_len,
next_tockens=0,
)
output[start:end] = attn_out[0].squeeze(0)
attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
output[:num_tokens] = attn_output[:num_tokens]
return output return output