fix: correct layout for npu_incre_flash_attention (BNSD requires B,H,1,D)

2026-02-10 20:23:03 +08:00
parent e7655a0745
commit a58c3fe973


@@ -419,7 +419,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
"""Decode-only via npu_incre_flash_attention.""" """Decode-only via npu_incre_flash_attention."""
import torch_npu # noqa: F401 import torch_npu # noqa: F401
q = query[:num_tokens].unsqueeze(1) # (B, 1, H, D) q = query[:num_tokens].unsqueeze(2) # (B, H, 1, D) for BNSD
attn_out = torch_npu.npu_incre_flash_attention( attn_out = torch_npu.npu_incre_flash_attention(
q, q,
@@ -434,7 +434,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
input_layout="BNSD", input_layout="BNSD",
) )
output[:num_tokens] = attn_out.squeeze(1) output[:num_tokens] = attn_out.squeeze(2)
return output return output
# ----------------------------------------------------------------- # -----------------------------------------------------------------
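
For reviewers: a minimal shape check of the decode-only change above, using plain torch as a stand-in (no Ascend device or torch_npu needed; the sizes are illustrative, not from any model config).

```python
import torch

# Illustrative sizes; in this path B == num_tokens (one token per decode request).
num_tokens, num_heads, head_dim = 4, 8, 128
query = torch.randn(num_tokens, num_heads, head_dim)   # (B, H, D)

q = query.unsqueeze(2)                                  # (B, H, 1, D) -> matches input_layout="BNSD"
assert q.shape == (num_tokens, num_heads, 1, head_dim)

# The kernel returns the same BNSD layout, so dim 2 (not dim 1) is the
# singleton sequence axis to drop on the way out.
attn_out = torch.zeros_like(q)                          # stand-in for the kernel output
out = attn_out.squeeze(2)                               # back to (B, H, D)
assert out.shape == query.shape
```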
@@ -509,7 +509,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         decode_seq_lens = seq_lens[decode_indices].tolist()
         decode_out = torch_npu.npu_incre_flash_attention(
-            decode_query.unsqueeze(1),
+            decode_query.unsqueeze(2),
             self._key_cache,
             self._value_cache,
             num_heads=self.num_heads,
@@ -523,7 +523,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         for i, idx in enumerate(decode_indices):
             token_pos = query_start_loc[idx].item()
-            output[token_pos] = decode_out[i].squeeze(0)
+            output[token_pos] = decode_out[i].squeeze(1)
         # --- Prefill tokens (per-request via npu_fusion_attention) ---
         if prefill_mask.any():
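
The same layout reasoning applies to the mixed prefill/decode path: each per-request slice of decode_out is now (H, 1, D), so dim 1 is the singleton to squeeze before scattering into the flat output buffer. A sketch with made-up sizes and hypothetical token positions (standing in for query_start_loc entries):

```python
import torch

num_decodes, num_heads, head_dim = 3, 8, 128
decode_out = torch.randn(num_decodes, num_heads, 1, head_dim)  # BNSD kernel result (stand-in)

output = torch.zeros(16, num_heads, head_dim)                  # hypothetical flat token buffer
token_positions = [0, 5, 9]                                    # hypothetical query_start_loc values

for i, token_pos in enumerate(token_positions):
    output[token_pos] = decode_out[i].squeeze(1)               # (H, 1, D) -> (H, D)

assert output[5].shape == (num_heads, head_dim)
```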