fix: use npu_fusion_attention loop (BSND) for prefill_no_cache to fix crash

2026-02-10 20:42:47 +08:00
parent 5337842e92
commit 37af1ddc1f


@@ -449,33 +449,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Prefill attention without KV cache (self-attention) using _npu_flash_attention."""
+        """Prefill attention without KV cache (self-attention) via per-req loop."""
         import torch_npu  # noqa: F401
-        # Huawei uses _npu_flash_attention for prefill
-        # Ensure contiguous inputs
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-        # mask needs to be contiguous and cast to expected format if needed
-        # but _npu_flash_attention handles generic mask?
-        # Huawei code: mask = attn_metadata.attn_mask...
-        # We'll pass it as is, assuming AscendMetadataBuilder created it correctly.
-
-        torch_npu._npu_flash_attention(
-            query=query,
-            key=key,
-            value=value,
-            mask=attn_metadata.attn_mask,
-            seq_len=attn_metadata.seq_lens,
-            scale_value=self.scale,
-            num_heads=self.num_heads,
-            num_kv_heads=self.num_kv_heads,
-            out=output
-        )
-        return output[:num_tokens]
+        query_start_loc = attn_metadata.query_start_loc
+        seq_lens = attn_metadata.seq_lens
+        num_reqs = len(seq_lens)
+
+        # Iterate and process each request independently to bypass TND issues
+        for i in range(num_reqs):
+            start = query_start_loc[i].item()
+            end = query_start_loc[i + 1].item()
+            q_len = end - start
+            # Extract q, k, v (BSND)
+            q = query[start:end].unsqueeze(0)
+            k = key[start:end].unsqueeze(0)
+            v = value[start:end].unsqueeze(0)
+
+            # Mask (lower triangular for causal)
+            attn_mask = torch.ones(
+                q_len, q_len, dtype=torch.bool, device=query.device
+            ).triu_(diagonal=1).unsqueeze(0)
+            # Run npu_fusion_attention (BSND)
+            attn_out = torch_npu.npu_fusion_attention(
+                q, k, v,
+                head_num=self.num_heads,
+                input_layout="BSND",
+                scale=self.scale,
+                atten_mask=attn_mask,
+                pre_tockens=2147483647,
+                next_tockens=0,
+            )
+            output[start:end] = attn_out[0]
+        return output

     # -----------------------------------------------------------------
     # Chunked prefill — mixed prefill+decode via npu_fusion_attention
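
For reference only (not part of the commit): each iteration of the new loop computes plain causal self-attention over one request's token span. A minimal pure-PyTorch sketch of the same computation, assuming vLLM-style query_start_loc metadata and query/key/value shaped [num_tokens, num_heads, head_dim] with num_kv_heads == num_heads; the helper name prefill_no_cache_reference is hypothetical and uses scaled_dot_product_attention in place of the NPU kernel:

    # Illustrative sketch; not the torch_npu code path above.
    import torch
    import torch.nn.functional as F

    def prefill_no_cache_reference(query, key, value, query_start_loc, scale):
        output = torch.empty_like(query)
        num_reqs = query_start_loc.numel() - 1
        for i in range(num_reqs):
            start = int(query_start_loc[i])
            end = int(query_start_loc[i + 1])
            # Reshape to [1, num_heads, q_len, head_dim] for SDPA
            q = query[start:end].permute(1, 0, 2).unsqueeze(0)
            k = key[start:end].permute(1, 0, 2).unsqueeze(0)
            v = value[start:end].permute(1, 0, 2).unsqueeze(0)
            # Causal self-attention over this request only
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
            output[start:end] = out.squeeze(0).permute(1, 0, 2)
        return output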