mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git
fix: use additive float mask (-inf) for npu_fusion_attention to resolve garbage output
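
The boolean upper-triangular mask previously passed to npu_fusion_attention produced garbage output on these code paths. Each site now builds an additive float mask instead: 0.0 for key positions a query may attend to, -inf for positions it must not, in the query's dtype and device. Below is a minimal illustration of that construction in plain torch on CPU; it is a sketch, not plugin code, and no torch_npu or Ascend NPU is required (tensor names and sizes are made up for the example).

# Illustration only: plain torch on CPU.
import torch

q_len = 4

# Old style: boolean mask, True marks the future positions to hide.
mask_bool = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)

# New style: additive float mask, 0.0 keeps a position, -inf removes it,
# reshaped to (1, 1, S, S) so it broadcasts over batch and heads.
inf_value = float("-inf")
attn_mask = (
    torch.zeros(q_len, q_len)
    .masked_fill_(mask_bool, inf_value)
    .unsqueeze(0)
    .unsqueeze(0)
)

# Added to the raw attention scores before softmax, -inf entries get zero weight.
scores = torch.randn(1, 1, q_len, q_len)
probs = torch.softmax(scores + attn_mask, dim=-1)
assert torch.all(probs[0, 0, 0, 1:] == 0)  # the first token attends only to itself
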
@@ -467,12 +467,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
 k = key[start:end].unsqueeze(0)
 v = value[start:end].unsqueeze(0)

-# Mask (lower triangular for causal)
-# npu_fusion_attention requires mask dim to be 2 or 4.
-# We use (B, 1, S, S) -> (1, 1, S, S)
-attn_mask = torch.ones(
+# Create additive mask (0 for keep, -inf for mask)
+inf_value = float("-inf")
+mask_bool = torch.ones(
     q_len, q_len, dtype=torch.bool, device=query.device
-).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
+).triu_(diagonal=1)
+attn_mask = torch.zeros(
+    q_len, q_len, dtype=query.dtype, device=query.device
+).masked_fill_(mask_bool, inf_value).unsqueeze(0).unsqueeze(0)

 # Run npu_fusion_attention (BSND)
 attn_out = torch_npu.npu_fusion_attention(
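
As a sanity check on the hunk above, the same (1, 1, S, S) additive mask can be compared against stock PyTorch attention. This is a hedged verification sketch, not part of the plugin: it uses torch.nn.functional.scaled_dot_product_attention on CPU with the standard (B, H, S, D) layout, not the BSND layout passed to npu_fusion_attention.

import torch
import torch.nn.functional as F

# Arbitrary small sizes for the check.
B, H, S, D = 1, 2, 5, 8
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Same construction as in the diff: bool triu -> additive -inf mask.
mask_bool = torch.ones(S, S, dtype=torch.bool).triu_(diagonal=1)
attn_mask = (
    torch.zeros(S, S)
    .masked_fill_(mask_bool, float("-inf"))
    .unsqueeze(0)
    .unsqueeze(0)
)

# The additive mask reproduces ordinary causal attention.
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out_masked, out_causal)
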
@@ -565,10 +567,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
     -1, self.num_kv_heads, self.head_size
 )[:kv_len]

-causal_mask = torch.ones(
+inf_value = float("-inf")
+mask_bool = torch.ones(
     q_len, kv_len, dtype=torch.bool, device=query.device
 ).triu_(diagonal=kv_len - q_len + 1)
+
+causal_mask = torch.zeros(
+    q_len, kv_len, dtype=query.dtype, device=query.device
+).masked_fill_(mask_bool, inf_value)

 attn_out = torch_npu.npu_fusion_attention(
     q.unsqueeze(0),
     gathered_k.unsqueeze(0),
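
The subtle part of the hunk above is the diagonal offset: with kv_len total key positions and q_len new queries, query row i sits at absolute position kv_len - q_len + i, so every key index strictly greater than that must be hidden, which is exactly what triu_(diagonal=kv_len - q_len + 1) marks as True before those entries are filled with -inf. A small plain-torch illustration with arbitrary sizes:

import torch

q_len, kv_len = 2, 5

# True entries are the ones that masked_fill_ turns into -inf.
mask_bool = torch.ones(q_len, kv_len, dtype=torch.bool).triu_(diagonal=kv_len - q_len + 1)

# Row 0 is the query at absolute position 3: key 4 is hidden.
# Row 1 is the query at absolute position 4: every key is visible.
assert mask_bool.tolist() == [
    [False, False, False, False, True],
    [False, False, False, False, False],
]
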
@@ -586,10 +593,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
 # Full self-attention (no prior cache)
 k = key[start:end]
 v = value[start:end]
-causal_mask = torch.ones(
+
+inf_value = float("-inf")
+mask_bool = torch.ones(
     q_len, q_len, dtype=torch.bool, device=query.device
 ).triu_(diagonal=1)
+
+causal_mask = torch.zeros(
+    q_len, q_len, dtype=query.dtype, device=query.device
+).masked_fill_(mask_bool, inf_value)

 attn_out = torch_npu.npu_fusion_attention(
     q.unsqueeze(0),
     k.unsqueeze(0),