feat: Add Ascend NPU attention backend utilizing torch_npu FlashAttention and KV cache operations.

2026-02-10 21:26:42 +08:00
parent 71fdf46880
commit 3aebca03d9

@@ -467,14 +467,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 k = key[start:end].unsqueeze(0)
                 v = value[start:end].unsqueeze(0)
-                # Create additive mask (0 for keep, -inf for mask)
-                inf_value = float("-inf")
-                mask_bool = torch.ones(
-                    q_len, q_len, dtype=torch.bool, device=query.device
-                ).triu_(diagonal=1)
-                attn_mask = torch.zeros(
-                    q_len, q_len, dtype=query.dtype, device=query.device
-                ).masked_fill_(mask_bool, inf_value).unsqueeze(0).unsqueeze(0)
+                # Create boolean mask (Lower triangle=True means Keep, Upper=False means Mask)
+                # npu_fusion_attention (sparse_mode=0) interprets True as Keep?
+                # Or if True=Mask, then tril masks Past (Garbage).
+                # But triu (Upper=True) produced Garbage.
+                # So we try tril (Lower=True).
+                attn_mask = torch.ones(
+                    q_len, q_len, dtype=torch.bool, device=query.device
+                ).tril_(diagonal=0).unsqueeze(0).unsqueeze(0)
                 # Run npu_fusion_attention (BSND)
                 attn_out = torch_npu.npu_fusion_attention(
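
This hunk swaps the additive float mask (0 for keep, -inf for masked) for a boolean lower-triangular mask. The following is a reference-only sketch in plain PyTorch (no torch_npu; the shapes are illustrative, not taken from the backend) showing that the removed and added constructions encode the same causal pattern, with a softmax reference to sanity-check the additive form:

import torch

torch.manual_seed(0)
q_len, num_heads, head_size = 4, 2, 8  # illustrative sizes only
q = torch.randn(q_len, num_heads, head_size)
k = torch.randn(q_len, num_heads, head_size)
v = torch.randn(q_len, num_heads, head_size)

# Removed construction: 0 where a query may attend, -inf above the diagonal.
mask_bool = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)
additive_mask = torch.zeros(q_len, q_len).masked_fill_(mask_bool, float("-inf"))

# New construction: boolean lower triangle, True on attendable positions.
keep_mask = torch.ones(q_len, q_len, dtype=torch.bool).tril_(diagonal=0)

# Both describe the same causal pattern: keep exactly where the additive mask is 0.
assert torch.equal(keep_mask, additive_mask == 0)

# Plain softmax attention with the additive mask, as a CPU reference.
scale = head_size ** -0.5
scores = torch.einsum("qhd,khd->hqk", q, k) * scale + additive_mask
ref_out = torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), v)
print(ref_out.shape)  # torch.Size([4, 2, 8])

Whether npu_fusion_attention reads the boolean mask as "keep" or as "mask out" is exactly what the new comments are unsure about; the sketch only demonstrates the pattern equivalence, not the kernel's convention.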
@@ -567,15 +567,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     -1, self.num_kv_heads, self.head_size
                 )[:kv_len]
-                inf_value = float("-inf")
-                mask_bool = torch.ones(
-                    q_len, kv_len, dtype=torch.bool, device=query.device
-                ).triu_(diagonal=kv_len - q_len + 1)
-                causal_mask = torch.zeros(
-                    q_len, kv_len, dtype=query.dtype, device=query.device
-                ).masked_fill_(mask_bool, inf_value)
+                causal_mask = torch.ones(
+                    q_len, kv_len, dtype=torch.bool, device=query.device
+                ).tril_(diagonal=kv_len - q_len)  # Adjusted for offset? Or just simple?
+                # Logic for chunked prefill mask (non-square):
+                # if q_len < kv_len (prefill extension), the mask logic is more involved.
+                # Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
+                # tril with diagonal adjustment: diagonal=kv_len - q_len aligns the main diagonal.
                 attn_out = torch_npu.npu_fusion_attention(
                     q.unsqueeze(0),
                     gathered_k.unsqueeze(0),
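
For the non-square chunked-prefill case, the new comment states the rule mask[i, j] = True if j <= i + (kv_len - q_len). A small standalone check (plain PyTorch, illustrative sizes) that tril_ with diagonal=kv_len - q_len produces exactly that predicate:

import torch

q_len, kv_len = 3, 7          # illustrative: query chunk shorter than cached keys
offset = kv_len - q_len

causal_mask = torch.ones(q_len, kv_len, dtype=torch.bool).tril_(diagonal=offset)

rows = torch.arange(q_len).unsqueeze(1)   # i
cols = torch.arange(kv_len).unsqueeze(0)  # j
expected = cols <= rows + offset          # keep (i, j) iff j <= i + offset
assert torch.equal(causal_mask, expected)
print(causal_mask.int())
# tensor([[1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1, 1]])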
@@ -594,14 +594,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 k = key[start:end]
                 v = value[start:end]
-                inf_value = float("-inf")
-                mask_bool = torch.ones(
-                    q_len, q_len, dtype=torch.bool, device=query.device
-                ).triu_(diagonal=1)
-                causal_mask = torch.zeros(
-                    q_len, q_len, dtype=query.dtype, device=query.device
-                ).masked_fill_(mask_bool, inf_value)
+                causal_mask = torch.ones(
+                    q_len, q_len, dtype=torch.bool, device=query.device
+                ).tril_(diagonal=0)
                 attn_out = torch_npu.npu_fusion_attention(
                     q.unsqueeze(0),
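
The comments in this commit are undecided whether the fused kernel treats True as "keep" (tril) or as "mask out" (triu). Whichever convention turns out to be right, the two masks are complements of each other; a minimal plain-PyTorch sketch of flipping between them:

import torch

q_len = 4
keep_mask = torch.ones(q_len, q_len, dtype=torch.bool).tril_(diagonal=0)  # True = may attend
drop_mask = ~keep_mask                                                    # True = masked out

# The complement of tril(0) is exactly the triu(1) mask the old code built.
assert torch.equal(drop_mask, torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1))

If tril proves to be the wrong convention for the kernel, inverting the boolean mask (or returning to triu(diagonal=1)) is a trivial change and does not require rebuilding the additive -inf mask.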