From a58c3fe973a5220f06c7d6ac6402357011184a87 Mon Sep 17 00:00:00 2001
From: handsomezhuzhu <2658601135@qq.com>
Date: Tue, 10 Feb 2026 20:23:03 +0800
Subject: [PATCH] fix: correct layout for npu_incre_flash_attention (BNSD requires B,H,1,D)

---
 vllm_npu/attention/attention_v1.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index 8afeb1a..7b3cfbd 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -419,7 +419,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         """Decode-only via npu_incre_flash_attention."""
         import torch_npu  # noqa: F401

-        q = query[:num_tokens].unsqueeze(1)  # (B, 1, H, D)
+        q = query[:num_tokens].unsqueeze(2)  # (B, H, 1, D) for BNSD

         attn_out = torch_npu.npu_incre_flash_attention(
             q,
@@ -434,7 +434,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             input_layout="BNSD",
         )

-        output[:num_tokens] = attn_out.squeeze(1)
+        output[:num_tokens] = attn_out.squeeze(2)
         return output

     # -----------------------------------------------------------------
@@ -509,7 +509,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         decode_seq_lens = seq_lens[decode_indices].tolist()

         decode_out = torch_npu.npu_incre_flash_attention(
-            decode_query.unsqueeze(1),
+            decode_query.unsqueeze(2),
             self._key_cache,
             self._value_cache,
             num_heads=self.num_heads,
@@ -523,7 +523,7 @@ class AscendAttentionBackendImpl(AttentionImpl):

         for i, idx in enumerate(decode_indices):
             token_pos = query_start_loc[idx].item()
-            output[token_pos] = decode_out[i].squeeze(0)
+            output[token_pos] = decode_out[i].squeeze(1)

         # --- Prefill tokens (per-request via npu_fusion_attention) ---
         if prefill_mask.any():
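
Shape sanity check for the layout change (illustrative sketch, not part of the patch; it assumes query has shape (num_tokens, num_heads, head_dim) as in the code above, and uses plain torch so it runs without an NPU):

    import torch

    num_tokens, num_heads, head_dim = 4, 8, 128
    query = torch.randn(num_tokens, num_heads, head_dim)

    q_old = query.unsqueeze(1)  # (B, 1, H, D): old layout, does not match "BNSD"
    q_new = query.unsqueeze(2)  # (B, H, 1, D): batch, num_heads, seq=1, head_dim

    assert q_old.shape == (num_tokens, 1, num_heads, head_dim)
    assert q_new.shape == (num_tokens, num_heads, 1, head_dim)

    # The attention output keeps the BNSD layout, so the singleton seq
    # dimension to drop before writing into the flat output buffer is dim 2.
    attn_out = torch.randn_like(q_new)
    assert attn_out.squeeze(2).shape == (num_tokens, num_heads, head_dim)

    # Per-request slices (decode_out[i] in the patch) are (H, 1, D), so the
    # seq dimension there is dim 1, hence squeeze(1) instead of squeeze(0).
    assert attn_out[0].squeeze(1).shape == (num_heads, head_dim)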