Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 11:42:30 +00:00)
fix: correct layout for npu_incre_flash_attention (BNSD requires B,H,1,D)
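Why the fix matters: with input_layout="BNSD" the kernel expects (batch, num_heads, seq_len, head_dim), and a decode step has seq_len = 1, so the singleton axis must sit in position 2, not position 1. A minimal shape sketch in plain PyTorch (illustrative sizes assumed, not taken from the repository):

import torch

# Assumed illustrative sizes: 4 decode tokens, 8 heads, head_dim 128.
# `query` stands in for the flat (num_tokens, num_heads, head_dim) tensor
# that the backend slices per decode step.
num_tokens, num_heads, head_dim = 4, 8, 128
query = torch.randn(num_tokens, num_heads, head_dim)

q_old = query[:num_tokens].unsqueeze(1)  # (B, 1, H, D): seq axis in position 1, mismatches BNSD
q_new = query[:num_tokens].unsqueeze(2)  # (B, H, 1, D): matches BNSD with S = 1

print(q_old.shape)  # torch.Size([4, 1, 8, 128])
print(q_new.shape)  # torch.Size([4, 8, 1, 128])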
@@ -419,7 +419,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         """Decode-only via npu_incre_flash_attention."""
         import torch_npu  # noqa: F401

-        q = query[:num_tokens].unsqueeze(1)  # (B, 1, H, D)
+        q = query[:num_tokens].unsqueeze(2)  # (B, H, 1, D) for BNSD

         attn_out = torch_npu.npu_incre_flash_attention(
             q,
@@ -434,7 +434,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             input_layout="BNSD",
         )

-        output[:num_tokens] = attn_out.squeeze(1)
+        output[:num_tokens] = attn_out.squeeze(2)
         return output

     # -----------------------------------------------------------------
@@ -509,7 +509,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             decode_seq_lens = seq_lens[decode_indices].tolist()

             decode_out = torch_npu.npu_incre_flash_attention(
-                decode_query.unsqueeze(1),
+                decode_query.unsqueeze(2),
                 self._key_cache,
                 self._value_cache,
                 num_heads=self.num_heads,
@@ -523,7 +523,7 @@ class AscendAttentionBackendImpl(AttentionImpl):

             for i, idx in enumerate(decode_indices):
                 token_pos = query_start_loc[idx].item()
-                output[token_pos] = decode_out[i].squeeze(0)
+                output[token_pos] = decode_out[i].squeeze(1)

             # --- Prefill tokens (per-request via npu_fusion_attention) ---
             if prefill_mask.any():
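The output-side changes follow from the same layout: under BNSD the decode result comes back as (B, H, 1, D), so the batched path drops the seq axis with squeeze(2), while the per-request path, after indexing out the batch dimension, drops it with squeeze(1). A small sketch with the same assumed sizes as above:

import torch

# Assumed illustrative BNSD output of the decode kernel: B=4, H=8, S=1, D=128.
attn_out = torch.randn(4, 8, 1, 128)

batched = attn_out.squeeze(2)     # (B, H, D), as used for output[:num_tokens]
per_req = attn_out[0].squeeze(1)  # (H, D), as used for output[token_pos]

print(batched.shape)  # torch.Size([4, 8, 128])
print(per_req.shape)  # torch.Size([8, 128])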