diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index 1da282e..781a707 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -467,12 +467,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
             k = key[start:end].unsqueeze(0)
             v = value[start:end].unsqueeze(0)
 
-            # Mask (lower triangular for causal)
-            # npu_fusion_attention requires mask dim to be 2 or 4.
-            # We use (B, 1, S, S) -> (1, 1, S, S)
-            attn_mask = torch.ones(
+            # Create additive mask (0 for keep, -inf for mask)
+            inf_value = float("-inf")
+            mask_bool = torch.ones(
                 q_len, q_len, dtype=torch.bool, device=query.device
-            ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
+            ).triu_(diagonal=1)
+            attn_mask = torch.zeros(
+                q_len, q_len, dtype=query.dtype, device=query.device
+            ).masked_fill_(mask_bool, inf_value).unsqueeze(0).unsqueeze(0)
 
             # Run npu_fusion_attention (BSND)
             attn_out = torch_npu.npu_fusion_attention(
@@ -565,9 +567,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     -1, self.num_kv_heads, self.head_size
                 )[:kv_len]
 
-                causal_mask = torch.ones(
+                inf_value = float("-inf")
+                mask_bool = torch.ones(
                     q_len, kv_len, dtype=torch.bool, device=query.device
                 ).triu_(diagonal=kv_len - q_len + 1)
+
+                causal_mask = torch.zeros(
+                    q_len, kv_len, dtype=query.dtype, device=query.device
+                ).masked_fill_(mask_bool, inf_value)
 
                 attn_out = torch_npu.npu_fusion_attention(
                     q.unsqueeze(0),
@@ -586,9 +593,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 # Full self-attention (no prior cache)
                 k = key[start:end]
                 v = value[start:end]
-                causal_mask = torch.ones(
+
+                inf_value = float("-inf")
+                mask_bool = torch.ones(
                     q_len, q_len, dtype=torch.bool, device=query.device
                 ).triu_(diagonal=1)
+
+                causal_mask = torch.zeros(
+                    q_len, q_len, dtype=query.dtype, device=query.device
+                ).masked_fill_(mask_bool, inf_value)
 
                 attn_out = torch_npu.npu_fusion_attention(
                     q.unsqueeze(0),
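
The sketch below (plain torch, no torch_npu, so it runs anywhere) illustrates the mask change this patch makes: the boolean upper-triangular mask is kept only as an intermediate, and the mask handed to the kernel becomes an additive float mask in the query dtype, 0.0 where attention is allowed and -inf where it is blocked. q_len and the dtype are illustrative assumptions; in the patch they come from the request length and query.dtype.

    # Minimal sketch of the additive-mask construction used in the patch.
    # Assumed values for illustration only: q_len=4, dtype=float32.
    import torch

    q_len = 4
    dtype = torch.float32

    # Old form: boolean mask, True above the diagonal marks blocked positions.
    mask_bool = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)

    # New form: additive float mask, broadcast to (1, 1, S, S) as in the patch.
    attn_mask = (
        torch.zeros(q_len, q_len, dtype=dtype)
        .masked_fill_(mask_bool, float("-inf"))
        .unsqueeze(0)
        .unsqueeze(0)
    )

    # Sanity check: adding attn_mask to attention logits zeroes out exactly
    # the positions the boolean mask marked as blocked.
    logits = torch.randn(1, 1, q_len, q_len, dtype=dtype)
    probs = torch.softmax(logits + attn_mask, dim=-1)
    assert torch.all(probs[0, 0][mask_bool] == 0)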