diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index 3669db6..1da282e 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -468,9 +468,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             v = value[start:end].unsqueeze(0)
 
             # Mask (lower triangular for causal)
+            # npu_fusion_attention requires the mask dim to be 2 or 4.
+            # We use the 4-D (B, N, S, S) layout with B = N = 1, i.e. (1, 1, S, S).
             attn_mask = torch.ones(
                 q_len, q_len, dtype=torch.bool, device=query.device
-            ).triu_(diagonal=1).unsqueeze(0)
+            ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
 
             # Run npu_fusion_attention (BSND)
             attn_out = torch_npu.npu_fusion_attention(
@@ -575,7 +577,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 input_layout="BSND",
                 scale=self.scale,
                 sparse_mode=0,
-                atten_mask=causal_mask.unsqueeze(0),
+                atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
                 pre_tockens=kv_len,
                 next_tockens=0,
             )
@@ -596,7 +598,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 input_layout="BSND",
                 scale=self.scale,
                 sparse_mode=0,
-                atten_mask=causal_mask.unsqueeze(0),
+                atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
                 pre_tockens=q_len,
                 next_tockens=0,
            )
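
For context, a minimal sketch of what the shape change does, runnable with plain PyTorch on CPU (no torch_npu required); `q_len = 8` is an arbitrary illustration value, not taken from the patch:

```python
import torch

q_len = 8  # arbitrary example sequence length

# Before: a single unsqueeze produced a 3-D mask of shape (1, S, S),
# which npu_fusion_attention rejects (it accepts only 2-D or 4-D masks).
mask_3d = (
    torch.ones(q_len, q_len, dtype=torch.bool)
    .triu_(diagonal=1)  # True strictly above the diagonal = masked-out positions
    .unsqueeze(0)
)
assert mask_3d.shape == (1, q_len, q_len)

# After: a second unsqueeze yields the broadcastable 4-D (B, N, S, S)
# layout with batch and head dims of 1, i.e. (1, 1, S, S).
mask_4d = mask_3d.unsqueeze(0)
assert mask_4d.shape == (1, 1, q_len, q_len)

print(mask_4d[0, 0].int())  # upper-triangular 1s: future positions are masked
```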