fix: use 4D mask (1, 1, S, S) for BSND layout in npu_fusion_attention

commit f54533fba7
parent 37af1ddc1f
Date:   2026-02-10 20:57:52 +08:00

@@ -468,9 +468,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
         v = value[start:end].unsqueeze(0)
         # Mask (lower triangular for causal)
+        # npu_fusion_attention requires mask dim to be 2 or 4.
+        # We use (B, 1, S, S) -> (1, 1, S, S)
         attn_mask = torch.ones(
             q_len, q_len, dtype=torch.bool, device=query.device
-        ).triu_(diagonal=1).unsqueeze(0)
+        ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
         # Run npu_fusion_attention (BSND)
         attn_out = torch_npu.npu_fusion_attention(
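
Note: the shape handling above can be checked in isolation. The snippet below is a minimal standalone sketch (CPU tensors and a made-up q_len rather than the backend's actual inputs), showing that the mask starts as a 2D strictly-upper-triangular boolean matrix and becomes (1, 1, S, S) after the two unsqueeze calls, i.e. the two ranks the fused op accepts:

    import torch

    q_len = 8  # placeholder sequence length
    # True above the diagonal marks the "future" positions to be masked out.
    attn_mask = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)
    assert attn_mask.dim() == 2                          # 2D form
    attn_mask_4d = attn_mask.unsqueeze(0).unsqueeze(0)
    assert attn_mask_4d.shape == (1, 1, q_len, q_len)    # (B=1, 1, S, S) form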
@@ -575,7 +577,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             input_layout="BSND",
             scale=self.scale,
             sparse_mode=0,
-            atten_mask=causal_mask.unsqueeze(0),
+            atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
             pre_tockens=kv_len,
             next_tockens=0,
         )
@@ -596,7 +598,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             input_layout="BSND",
             scale=self.scale,
             sparse_mode=0,
-            atten_mask=causal_mask.unsqueeze(0),
+            atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
             pre_tockens=q_len,
             next_tockens=0,
         )
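
For reference, the plain-PyTorch sketch below illustrates what the (1, 1, S, S) mask encodes under the BSND layout. It is not the NPU kernel and every shape is a made-up placeholder; it assumes, as the triu_(diagonal=1) construction implies, that True in the mask means "do not attend to this position":

    import math
    import torch

    B, S, N, D = 1, 8, 4, 16           # placeholder batch/seq/heads/head_dim
    q = torch.randn(B, S, N, D)        # BSND layout, as in the diff
    k = torch.randn(B, S, N, D)
    v = torch.randn(B, S, N, D)

    mask = torch.ones(S, S, dtype=torch.bool).triu_(diagonal=1)  # (S, S)
    mask_4d = mask.unsqueeze(0).unsqueeze(0)                      # (1, 1, S, S)

    # BSND -> BNSD so the matmuls run per head.
    qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))           # (B, N, S, D)

    scores = qh @ kh.transpose(-2, -1) / math.sqrt(D)             # (B, N, S, S)
    # The (1, 1, S, S) mask broadcasts over batch and heads;
    # True positions (the strict upper triangle) are excluded from attention.
    scores = scores.masked_fill(mask_4d, float("-inf"))
    out = (torch.softmax(scores, dim=-1) @ vh).transpose(1, 2)    # back to BSND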