fix: use additive float mask (-inf) for npu_fusion_attention to resolve garbage output

2026-02-10 21:16:03 +08:00
parent f54533fba7
commit 71fdf46880


@@ -467,12 +467,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
 k = key[start:end].unsqueeze(0)
 v = value[start:end].unsqueeze(0)
-# Mask (lower triangular for causal)
-# npu_fusion_attention requires mask dim to be 2 or 4.
-# We use (B, 1, S, S) -> (1, 1, S, S)
-attn_mask = torch.ones(
+# Create additive mask (0 for keep, -inf for mask)
+inf_value = float("-inf")
+mask_bool = torch.ones(
     q_len, q_len, dtype=torch.bool, device=query.device
-).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
+).triu_(diagonal=1)
+attn_mask = torch.zeros(
+    q_len, q_len, dtype=query.dtype, device=query.device
+).masked_fill_(mask_bool, inf_value).unsqueeze(0).unsqueeze(0)

 # Run npu_fusion_attention (BSND)
 attn_out = torch_npu.npu_fusion_attention(
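
For reference, a minimal CPU sketch of the new mask construction in this hunk (plain torch only, no torch_npu required; q_len and dtype are illustrative values, not taken from the code above):

import torch

q_len = 4
dtype = torch.float16

# Boolean upper-triangular mask: True marks the future positions to hide.
mask_bool = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)

# Additive float mask: 0.0 where attention is allowed, -inf where it is
# masked, broadcast to (1, 1, S, S) since npu_fusion_attention expects a
# 2-D or 4-D mask.
attn_mask = (
    torch.zeros(q_len, q_len, dtype=dtype)
    .masked_fill_(mask_bool, float("-inf"))
    .unsqueeze(0)
    .unsqueeze(0)
)

print(attn_mask[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]], dtype=torch.float16)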
@@ -565,9 +567,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
     -1, self.num_kv_heads, self.head_size
 )[:kv_len]
-causal_mask = torch.ones(
+inf_value = float("-inf")
+mask_bool = torch.ones(
     q_len, kv_len, dtype=torch.bool, device=query.device
 ).triu_(diagonal=kv_len - q_len + 1)
+causal_mask = torch.zeros(
+    q_len, kv_len, dtype=query.dtype, device=query.device
+).masked_fill_(mask_bool, inf_value)

 attn_out = torch_npu.npu_fusion_attention(
     q.unsqueeze(0),
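
The hunk above covers the cached-KV path, where q_len new queries attend to kv_len total keys; the diagonal offset kv_len - q_len + 1 lets query i see all cached keys plus new keys up to its own position. A toy-sized sketch (sizes assumed for illustration):

import torch

q_len, kv_len = 2, 5  # e.g. 3 cached tokens followed by 2 new tokens

mask_bool = torch.ones(q_len, kv_len, dtype=torch.bool).triu_(
    diagonal=kv_len - q_len + 1
)
causal_mask = torch.zeros(q_len, kv_len, dtype=torch.float16).masked_fill_(
    mask_bool, float("-inf")
)

print(causal_mask)
# tensor([[0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]], dtype=torch.float16)
# Query 0 (overall position 3) sees keys 0..3; query 1 sees all 5 keys.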
@@ -586,9 +593,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
 # Full self-attention (no prior cache)
 k = key[start:end]
 v = value[start:end]
-causal_mask = torch.ones(
+inf_value = float("-inf")
+mask_bool = torch.ones(
     q_len, q_len, dtype=torch.bool, device=query.device
 ).triu_(diagonal=1)
+causal_mask = torch.zeros(
+    q_len, q_len, dtype=query.dtype, device=query.device
+).masked_fill_(mask_bool, inf_value)

 attn_out = torch_npu.npu_fusion_attention(
     q.unsqueeze(0),
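
Why the additive form fixes the garbage output: if a kernel reads a boolean mask with the opposite keep/drop convention from the caller, attention flows only to future tokens, whereas an additive 0 / -inf mask is unambiguous. The check below uses torch.nn.functional.scaled_dot_product_attention as a CPU stand-in for npu_fusion_attention (an assumption for illustration; the NPU kernel's boolean-mask semantics are not documented in this diff) to confirm the additive mask reproduces causal attention:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 4, 8
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Reference: built-in causal attention.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Additive mask built exactly as in the hunks above.
mask_bool = torch.ones(S, S, dtype=torch.bool).triu_(diagonal=1)
additive = torch.zeros(S, S).masked_fill_(mask_bool, float("-inf"))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)

assert torch.allclose(ref, out, atol=1e-6)
print("additive -inf mask matches causal attention")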