Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git
fix: use 4D mask (1, 1, S, S) for BSND layout in npu_fusion_attention
@@ -468,9 +468,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             v = value[start:end].unsqueeze(0)
 
             # Mask (lower triangular for causal)
+            # npu_fusion_attention requires mask dim to be 2 or 4.
+            # We use (B, 1, S, S) -> (1, 1, S, S)
             attn_mask = torch.ones(
                 q_len, q_len, dtype=torch.bool, device=query.device
-            ).triu_(diagonal=1).unsqueeze(0)
+            ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
 
             # Run npu_fusion_attention (BSND)
             attn_out = torch_npu.npu_fusion_attention(
@@ -575,7 +577,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 input_layout="BSND",
                 scale=self.scale,
                 sparse_mode=0,
-                atten_mask=causal_mask.unsqueeze(0),
+                atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
                 pre_tockens=kv_len,
                 next_tockens=0,
             )
@@ -596,7 +598,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 input_layout="BSND",
                 scale=self.scale,
                 sparse_mode=0,
-                atten_mask=causal_mask.unsqueeze(0),
+                atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
                 pre_tockens=q_len,
                 next_tockens=0,
             )
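For context, a minimal standalone sketch of the mask shape this commit switches to. It mirrors the construction in the first hunk (torch.ones + triu_ + two unsqueeze calls); the function name build_causal_mask_4d and the small q_len value below are illustrative only, and the shape check runs on plain PyTorch without an Ascend NPU. On device, the resulting tensor would be passed as atten_mask to torch_npu.npu_fusion_attention with input_layout="BSND", as shown in the diff above.

    import torch

    def build_causal_mask_4d(q_len: int, device=None) -> torch.Tensor:
        # 2-D strict upper-triangular boolean mask: True marks future positions
        # that must be masked out for causal attention.
        mask_2d = torch.ones(
            q_len, q_len, dtype=torch.bool, device=device
        ).triu_(diagonal=1)
        # npu_fusion_attention accepts only 2-D or 4-D masks; for BSND inputs
        # the commit uses (1, 1, S, S) so the mask broadcasts over batch and heads.
        return mask_2d.unsqueeze(0).unsqueeze(0)

    mask = build_causal_mask_4d(4)
    print(mask.shape)    # torch.Size([1, 1, 4, 4])
    print(mask[0, 0])    # True above the diagonal, False on/below it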