fix: correct layout for npu_incre_flash_attention (BNSD requires B,H,1,D)
@@ -419,7 +419,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         """Decode-only via npu_incre_flash_attention."""
         import torch_npu  # noqa: F401
 
-        q = query[:num_tokens].unsqueeze(1)  # (B, 1, H, D)
+        q = query[:num_tokens].unsqueeze(2)  # (B, H, 1, D) for BNSD
 
         attn_out = torch_npu.npu_incre_flash_attention(
             q,
@@ -434,7 +434,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             input_layout="BNSD",
         )
 
-        output[:num_tokens] = attn_out.squeeze(1)
+        output[:num_tokens] = attn_out.squeeze(2)
         return output
 
     # -----------------------------------------------------------------
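
For reference, a minimal shape check of the fix above, runnable on plain torch (the torch_npu kernel itself is NPU-only and is not invoked here; B, H, D are illustrative values, not taken from the code):

import torch

B, H, D = 4, 8, 128               # batch, heads, head_dim (illustrative)
query = torch.randn(B, H, D)      # one decode token per sequence

wrong = query.unsqueeze(1)        # (B, 1, H, D): heads land in the S slot
fixed = query.unsqueeze(2)        # (B, H, 1, D): batch, num_heads, seq, dim -- BNSD

assert wrong.shape == (B, 1, H, D)
assert fixed.shape == (B, H, 1, D)

# The kernel returns the same BNSD layout, so the write-back must drop
# dim 2, not dim 1:
attn_out = fixed                  # stand-in for the (B, H, 1, D) kernel output
assert attn_out.squeeze(2).shape == (B, H, D)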
@@ -509,7 +509,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         decode_seq_lens = seq_lens[decode_indices].tolist()
 
         decode_out = torch_npu.npu_incre_flash_attention(
-            decode_query.unsqueeze(1),
+            decode_query.unsqueeze(2),
             self._key_cache,
             self._value_cache,
             num_heads=self.num_heads,
@@ -523,7 +523,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
 
         for i, idx in enumerate(decode_indices):
             token_pos = query_start_loc[idx].item()
-            output[token_pos] = decode_out[i].squeeze(0)
+            output[token_pos] = decode_out[i].squeeze(1)
 
         # --- Prefill tokens (per-request via npu_fusion_attention) ---
         if prefill_mask.any():
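
The mixed-batch path indexes per request before writing back, which shifts the squeeze axis down by one: decode_out[i] has already dropped the batch dimension, leaving (H, 1, D). A sketch of the same shape arithmetic, under the same illustrative assumptions as above:

import torch

B, H, D = 3, 8, 128
decode_out = torch.randn(B, H, 1, D)   # BNSD output from the decode kernel

for i in range(B):
    per_req = decode_out[i]            # indexing drops the batch dim -> (H, 1, D)
    token = per_req.squeeze(1)         # drop the singleton seq dim   -> (H, D)
    assert token.shape == (H, D)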