From 3aebca03d94e7793f19603792266a468f398e276 Mon Sep 17 00:00:00 2001
From: handsomezhuzhu <2658601135@qq.com>
Date: Tue, 10 Feb 2026 21:26:42 +0800
Subject: [PATCH] feat: Add Ascend NPU attention backend utilizing torch_npu
 FlashAttention and KV cache operations.

---
 vllm_npu/attention/attention_v1.py | 37 ++++++++++++++++---------------------
 1 file changed, 16 insertions(+), 21 deletions(-)

diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index 781a707..dc5bcce 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -467,14 +467,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
             k = key[start:end].unsqueeze(0)
             v = value[start:end].unsqueeze(0)
 
-            # Create additive mask (0 for keep, -inf for mask)
-            inf_value = float("-inf")
-            mask_bool = torch.ones(
+            # Build a boolean causal mask: lower triangle (incl. diagonal) = True.
+            # With sparse_mode=0, an upper-triangular (triu) bool mask produced
+            # garbage output, which suggests npu_fusion_attention treats True as
+            # "keep" rather than "mask out" here, so the tril mask below marks
+            # exactly the positions each query token is allowed to attend to.
+            attn_mask = torch.ones(
                 q_len, q_len, dtype=torch.bool, device=query.device
-            ).triu_(diagonal=1)
-            attn_mask = torch.zeros(
-                q_len, q_len, dtype=query.dtype, device=query.device
-            ).masked_fill_(mask_bool, inf_value).unsqueeze(0).unsqueeze(0)
+            ).tril_(diagonal=0).unsqueeze(0).unsqueeze(0)
 
             # Run npu_fusion_attention (BSND)
             attn_out = torch_npu.npu_fusion_attention(
@@ -567,15 +567,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 -1, self.num_kv_heads, self.head_size
             )[:kv_len]
 
-            inf_value = float("-inf")
-            mask_bool = torch.ones(
+            causal_mask = torch.ones(
                 q_len, kv_len, dtype=torch.bool, device=query.device
-            ).triu_(diagonal=kv_len - q_len + 1)
+            ).tril_(diagonal=kv_len - q_len)
+            # Non-square causal mask for chunked prefill (q_len <= kv_len):
+            # query position i may attend to key positions j with
+            # j <= i + (kv_len - q_len), so mask[i, j] = True (keep) exactly
+            # for those positions; shifting the diagonal by kv_len - q_len
+            # aligns the last query row with the last key column.
 
-            causal_mask = torch.zeros(
-                q_len, kv_len, dtype=query.dtype, device=query.device
-            ).masked_fill_(mask_bool, inf_value)
-
             attn_out = torch_npu.npu_fusion_attention(
                 q.unsqueeze(0),
                 gathered_k.unsqueeze(0),
@@ -594,14 +594,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
             k = key[start:end]
             v = value[start:end]
 
-            inf_value = float("-inf")
-            mask_bool = torch.ones(
+            causal_mask = torch.ones(
                 q_len, q_len, dtype=torch.bool, device=query.device
-            ).triu_(diagonal=1)
-
-            causal_mask = torch.zeros(
-                q_len, q_len, dtype=query.dtype, device=query.device
-            ).masked_fill_(mask_bool, inf_value)
+            ).tril_(diagonal=0)
 
             attn_out = torch_npu.npu_fusion_attention(
                 q.unsqueeze(0),
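
Note (illustration only, not part of the patch): the chunked-prefill mask in the second hunk encodes the rule mask[i, j] = True exactly when j <= i + (kv_len - q_len). The small CPU-only sketch below checks that the tril construction used in the patch matches that rule; it only exercises the mask arithmetic, never calls torch_npu, and the helper names are made up for the example.

import torch

def tril_mask(q_len: int, kv_len: int) -> torch.Tensor:
    # Construction used in the patch: True = keep, diagonal shifted by kv_len - q_len.
    return torch.ones(q_len, kv_len, dtype=torch.bool).tril_(diagonal=kv_len - q_len)

def rule_mask(q_len: int, kv_len: int) -> torch.Tensor:
    # Direct statement of the rule: query i attends to key j iff j <= i + (kv_len - q_len).
    i = torch.arange(q_len).unsqueeze(1)
    j = torch.arange(kv_len).unsqueeze(0)
    return j <= i + (kv_len - q_len)

for q_len, kv_len in [(1, 7), (4, 4), (3, 11)]:
    assert torch.equal(tril_mask(q_len, kv_len), rule_mask(q_len, kv_len))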
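
Also for illustration only: the removed additive mask (0 to keep, -inf to mask, built from a triu bool mask) and the new boolean mask (tril, True to keep) describe the same causal pattern under opposite conventions. The sketch below checks that the positions kept by the new mask are exactly the positions the old mask left at 0; which convention the bool atten_mask of npu_fusion_attention actually expects is the assumption this patch relies on.

import torch

q_len = 5

# Old construction from the removed code: additive mask, -inf strictly above the diagonal.
mask_bool = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)
additive = torch.zeros(q_len, q_len).masked_fill_(mask_bool, float("-inf"))

# New construction from the patch: boolean mask, True on and below the diagonal.
keep = torch.ones(q_len, q_len, dtype=torch.bool).tril_(diagonal=0)

# Same causal pattern, opposite conventions: kept positions are exactly where
# the additive mask is 0.
assert torch.equal(keep, additive == 0)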