Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
fix: use 4D mask (1, 1, S, S) for BSND layout in npu_fusion_attention
@@ -468,9 +468,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
         v = value[start:end].unsqueeze(0)
 
         # Mask (lower triangular for causal)
+        # npu_fusion_attention requires mask dim to be 2 or 4.
+        # We use (B, 1, S, S) -> (1, 1, S, S)
         attn_mask = torch.ones(
             q_len, q_len, dtype=torch.bool, device=query.device
-        ).triu_(diagonal=1).unsqueeze(0)
+        ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
 
         # Run npu_fusion_attention (BSND)
         attn_out = torch_npu.npu_fusion_attention(
@@ -575,7 +577,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             input_layout="BSND",
             scale=self.scale,
             sparse_mode=0,
-            atten_mask=causal_mask.unsqueeze(0),
+            atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
             pre_tockens=kv_len,
             next_tockens=0,
         )
@@ -596,7 +598,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             input_layout="BSND",
             scale=self.scale,
             sparse_mode=0,
-            atten_mask=causal_mask.unsqueeze(0),
+            atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
             pre_tockens=q_len,
             next_tockens=0,
         )
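For readers who want to inspect the new mask shape without an NPU, here is a minimal, self-contained sketch of the construction this patch switches to. The helper name build_causal_mask_4d is purely illustrative; the only facts taken from the commit are that npu_fusion_attention accepts a 2D or 4D atten_mask and that a (1, 1, S, S) boolean mask with True on the strict upper triangle is used for the BSND layout.

import torch

def build_causal_mask_4d(q_len: int, device=None) -> torch.Tensor:
    # 2D strict upper triangle: True marks positions a query may NOT attend to,
    # matching the triu_(diagonal=1) pattern in the patched code.
    mask_2d = torch.ones(q_len, q_len, dtype=torch.bool, device=device).triu_(diagonal=1)
    # The commit's comment says the kernel accepts mask dim 2 or 4; for BSND it
    # uses a (1, 1, S, S) mask, so add two leading singleton dims.
    return mask_2d.unsqueeze(0).unsqueeze(0)

mask = build_causal_mask_4d(4)
print(mask.shape)        # torch.Size([1, 1, 4, 4])
print(mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)

Assuming causal_mask starts out 2D (S, S), the previous single unsqueeze(0) at the two call sites produced a 3D tensor, which falls outside the 2-or-4-dim requirement the commit cites; the added second unsqueeze(0) brings it to the supported (1, 1, S, S) shape.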