mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
feat: Add Ascend NPU attention backend with NPU-specific FlashAttention, LayerNorm, and Rotary Embedding implementations.

@@ -380,6 +380,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
 query = query.view(-1, self.num_heads, self.head_size)
 key = key.view(-1, self.num_kv_heads, self.head_size)
 value = value.view(-1, self.num_kv_heads, self.head_size)
+# TODO: Remove this contiguous in the future.
+value = value.contiguous()
 
 # Step 1: Update KV cache
 if key is not None and value is not None:
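
The two added lines force value into contiguous memory before it reaches the fused kernel. Below is a minimal sketch of what that guarantees; the shapes are hypothetical and the snippet is illustrative, not the plugin's actual code path.

import torch

# Illustrative sketch only (hypothetical shapes): the flat projection output is
# viewed as [num_tokens, num_kv_heads, head_size]; .contiguous() then guarantees
# dense row-major memory, which fused NPU kernels typically require. It is a
# no-op copy-wise when the tensor is already contiguous.
num_tokens, num_kv_heads, head_size = 8, 2, 64
value = torch.randn(num_tokens, num_kv_heads * head_size)

value = value.view(-1, num_kv_heads, head_size)
value = value.contiguous()  # defensive copy, mirroring the TODO'd workaround
assert value.shape == (num_tokens, num_kv_heads, head_size) and value.is_contiguous()
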
@@ -467,14 +469,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
 k = key[start:end].unsqueeze(0)
 v = value[start:end].unsqueeze(0)
 
-# Create boolean mask (Lower triangle=True means Keep, Upper=False means Mask)
-# npu_fusion_attention (sparse_mode=0) interprets True as Keep?
-# Or if True=Mask, then tril masks Past (Garbage).
-# But triu (Upper=True) produced Garbage.
-# So we try tril (Lower=True).
+# npu_fusion_attention: True = mask out (do NOT attend)
+# Upper triangle = future tokens = should be masked out
 attn_mask = torch.ones(
 q_len, q_len, dtype=torch.bool, device=query.device
-).tril_(diagonal=0).unsqueeze(0).unsqueeze(0)
+).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
 
 # Run npu_fusion_attention (BSND)
 attn_out = torch_npu.npu_fusion_attention(
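
The rewritten comments pin down the convention used here: True marks a position to exclude, so a causal mask is the strict upper triangle (future tokens). A small plain-PyTorch check of that convention, independent of the NPU kernel:

import torch

# Sketch of the convention described in the new comments (plain PyTorch, no
# torch_npu): True = mask out, so a causal mask is the strict upper triangle.
q_len = 4
attn_mask = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# Reference attention scores with the masked (future) positions pushed to
# -inf: every probability above the diagonal ends up exactly zero.
scores = torch.randn(q_len, q_len)
probs = scores.masked_fill(attn_mask, float("-inf")).softmax(dim=-1)
assert torch.equal(probs.triu(diagonal=1), torch.zeros_like(probs).triu(diagonal=1))

# The diff's two unsqueeze(0) calls reshape the mask to [1, 1, q_len, q_len]
# so it broadcasts across batch and heads.
assert attn_mask.unsqueeze(0).unsqueeze(0).shape == (1, 1, q_len, q_len)
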
@@ -567,9 +566,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
 -1, self.num_kv_heads, self.head_size
 )[:kv_len]
 
+# npu_fusion_attention: True = mask out
+# For chunked prefill, mask future positions
 causal_mask = torch.ones(
 q_len, kv_len, dtype=torch.bool, device=query.device
-).tril_(diagonal=kv_len - q_len)  # Adjusted for offset? Or just simple?
+).triu_(diagonal=kv_len - q_len + 1)
 # logic for chunked prefill mask (non-square)?
 # If q_len < kv_len (prefill extension), mask logic is complex.
 # Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
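
For the non-square chunked-prefill case the same rule applies with an offset: query i of the current chunk sits at absolute position i + (kv_len - q_len), so only keys strictly beyond that are future and get True. A hedged sanity-check sketch with hypothetical sizes:

import torch

# Sketch of the non-square chunked-prefill mask: q_len new queries attend over
# kv_len = past + q_len key positions. Keys j > i + (kv_len - q_len) are in the
# future of query i and must be masked out (True = mask out), which is exactly
# triu_ with diagonal = kv_len - q_len + 1.
q_len, kv_len = 3, 5  # hypothetical: 2 cached tokens + 3 new ones
offset = kv_len - q_len

causal_mask = torch.ones(q_len, kv_len, dtype=torch.bool).triu_(diagonal=offset + 1)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])

# Equivalent element-wise definition, useful as a sanity check:
i = torch.arange(q_len).unsqueeze(1)
j = torch.arange(kv_len).unsqueeze(0)
assert torch.equal(causal_mask, j > i + offset)
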
@@ -594,9 +595,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
 k = key[start:end]
 v = value[start:end]
 
+# npu_fusion_attention: True = mask out
 causal_mask = torch.ones(
 q_len, q_len, dtype=torch.bool, device=query.device
-).tril_(diagonal=0)
+).triu_(diagonal=1)
 
 attn_out = torch_npu.npu_fusion_attention(
 q.unsqueeze(0),
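
One way to cross-check the convention on CPU, independent of the NPU kernel: PyTorch's scaled_dot_product_attention uses the opposite boolean meaning (True = keep), so a reference comparison has to invert the mask the comments above describe for npu_fusion_attention. This is a hypothetical check, not part of the plugin.

import torch
import torch.nn.functional as F

# Hypothetical CPU cross-check: invert the mask-out mask before handing it to
# F.scaled_dot_product_attention, whose boolean attn_mask means True = keep.
q_len, num_heads, head_size = 4, 2, 8
q = torch.randn(1, num_heads, q_len, head_size)
k = torch.randn(1, num_heads, q_len, head_size)
v = torch.randn(1, num_heads, q_len, head_size)

mask_out = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)  # True = exclude
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask_out)       # True = keep

# The inverted mask reproduces the built-in causal path, confirming that
# triu_(diagonal=1) under a mask-out convention is an ordinary causal mask.
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(ref, causal, atol=1e-5)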