fix: correct layout for npu_incre_flash_attention (BNSD requires B,H,1,D)

2026-02-10 20:23:03 +08:00
parent e7655a0745
commit a58c3fe973

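npu_incre_flash_attention is invoked with input_layout="BNSD", i.e. (batch, num_heads, seq_len, head_dim); with a single decode token per sequence the query therefore has to be (B, H, 1, D), not (B, 1, H, D). A minimal shape sketch of why unsqueeze(2) is the right call on a (B, H, D) decode query (plain torch only, illustrative sizes, no Ascend kernel involved):

import torch

# Decode-time query: one token per sequence -> (batch, num_heads, head_dim).
B, H, D = 4, 8, 128                  # illustrative sizes, not taken from the model config
query = torch.randn(B, H, D)

print(query.unsqueeze(1).shape)      # torch.Size([4, 1, 8, 128]) -> (B, 1, H, D), not BNSD
print(query.unsqueeze(2).shape)      # torch.Size([4, 8, 1, 128]) -> (B, H, 1, D), matches BNSD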

@@ -419,7 +419,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
"""Decode-only via npu_incre_flash_attention."""
import torch_npu # noqa: F401
-q = query[:num_tokens].unsqueeze(1)  # (B, 1, H, D)
+q = query[:num_tokens].unsqueeze(2)  # (B, H, 1, D) for BNSD
attn_out = torch_npu.npu_incre_flash_attention(
q,
@@ -434,7 +434,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
input_layout="BNSD",
)
-output[:num_tokens] = attn_out.squeeze(1)
+output[:num_tokens] = attn_out.squeeze(2)
return output
# -----------------------------------------------------------------
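With BNSD the kernel output keeps the same layout, (B, H, 1, D), so the singleton sequence dimension sits at axis 2 for the batched assignment above and at axis 1 once a single request has been indexed out, as in the per-request decode path below. A quick shape check of both squeeze fixes (plain torch, illustrative sizes):

import torch

B, H, D = 4, 8, 128                  # illustrative sizes
attn_out = torch.randn(B, H, 1, D)   # BNSD output, seq_len == 1 for decode

print(attn_out.squeeze(1).shape)     # torch.Size([4, 8, 1, 128]) - no-op since H > 1
print(attn_out.squeeze(2).shape)     # torch.Size([4, 8, 128])    - (B, H, D), matches output[:num_tokens]

one_req = attn_out[0]                # (H, 1, D) after indexing one decode request
print(one_req.squeeze(1).shape)      # torch.Size([8, 128])       - (H, D), matches output[token_pos]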
@@ -509,7 +509,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
decode_seq_lens = seq_lens[decode_indices].tolist()
decode_out = torch_npu.npu_incre_flash_attention(
-decode_query.unsqueeze(1),
+decode_query.unsqueeze(2),
self._key_cache,
self._value_cache,
num_heads=self.num_heads,
@@ -523,7 +523,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
for i, idx in enumerate(decode_indices):
token_pos = query_start_loc[idx].item()
-output[token_pos] = decode_out[i].squeeze(0)
+output[token_pos] = decode_out[i].squeeze(1)
# --- Prefill tokens (per-request via npu_fusion_attention) ---
if prefill_mask.any():