Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
fix: use npu_fusion_attention loop (BSND) for prefill_no_cache to fix crash
@@ -449,33 +449,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Prefill attention without KV cache (self-attention) using _npu_flash_attention."""
+        """Prefill attention without KV cache (self-attention) via per-req loop."""
         import torch_npu  # noqa: F401
 
-        # Huawei uses _npu_flash_attention for prefill
-        # Ensure contiguous inputs
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-
-        # mask needs to be contiguous and cast to expected format if needed
-        # but _npu_flash_attention handles generic mask?
-        # Huawei code: mask = attn_metadata.attn_mask...
-        # We'll pass it as is, assuming AscendMetadataBuilder created it correctly.
-
-        torch_npu._npu_flash_attention(
-            query=query,
-            key=key,
-            value=value,
-            mask=attn_metadata.attn_mask,
-            seq_len=attn_metadata.seq_lens,
-            scale_value=self.scale,
-            num_heads=self.num_heads,
-            num_kv_heads=self.num_kv_heads,
-            out=output
-        )
-
-        return output[:num_tokens]
+        query_start_loc = attn_metadata.query_start_loc
+        seq_lens = attn_metadata.seq_lens
+        num_reqs = len(seq_lens)
+
+        # Iterate and process each request independently to bypass TND issues
+        for i in range(num_reqs):
+            start = query_start_loc[i].item()
+            end = query_start_loc[i + 1].item()
+            q_len = end - start
+
+            # Extract q, k, v (BSND)
+            q = query[start:end].unsqueeze(0)
+            k = key[start:end].unsqueeze(0)
+            v = value[start:end].unsqueeze(0)
+
+            # Mask (lower triangular for causal)
+            attn_mask = torch.ones(
+                q_len, q_len, dtype=torch.bool, device=query.device
+            ).triu_(diagonal=1).unsqueeze(0)
+
+            # Run npu_fusion_attention (BSND)
+            attn_out = torch_npu.npu_fusion_attention(
+                q, k, v,
+                head_num=self.num_heads,
+                input_layout="BSND",
+                scale=self.scale,
+                atten_mask=attn_mask,
+                pre_tockens=2147483647,
+                next_tockens=0,
+            )
+
+            output[start:end] = attn_out[0]
+
+        return output
 
     # -----------------------------------------------------------------
     # Chunked prefill — mixed prefill+decode via npu_fusion_attention
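For context, a minimal CPU-side sketch of what the new per-request loop computes. It is not the plugin code: torch.nn.functional.scaled_dot_product_attention stands in for torch_npu.npu_fusion_attention, and the sequence lengths, head count, and head dimension are made-up values. The per-request slicing via query_start_loc, the unsqueeze(0) batch dimension, and the triu_(diagonal=1) boolean mask (True marks positions excluded from attention) follow the diff above.

# Reference sketch only: per-request causal attention over packed prefill
# tokens, mirroring the loop introduced in the diff. SDPA replaces the NPU
# kernel; all shapes below are illustrative assumptions.
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 32
seq_lens = torch.tensor([5, 3, 7])                       # per-request prompt lengths
query_start_loc = torch.cat(                             # cumulative offsets, len = num_reqs + 1
    [torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)]
)
num_tokens = int(seq_lens.sum())

# Packed (num_tokens, num_heads, head_dim) tensors, as the backend receives them.
query = torch.randn(num_tokens, num_heads, head_dim)
key = torch.randn(num_tokens, num_heads, head_dim)
value = torch.randn(num_tokens, num_heads, head_dim)
output = torch.empty_like(query)

for i in range(len(seq_lens)):
    start = int(query_start_loc[i])
    end = int(query_start_loc[i + 1])
    q_len = end - start

    # BSND: add a batch dim of 1, keep (seq, heads, dim) for this request only.
    q = query[start:end].unsqueeze(0)
    k = key[start:end].unsqueeze(0)
    v = value[start:end].unsqueeze(0)

    # Same convention as the diff: True above the diagonal = masked out,
    # so each token attends only to itself and earlier tokens.
    masked = torch.ones(q_len, q_len, dtype=torch.bool).triu_(diagonal=1)

    # SDPA expects (batch, heads, seq, dim) and a mask where True = keep,
    # hence the transpose and the inversion relative to the NPU convention.
    # The default SDPA scale of 1/sqrt(head_dim) plays the role of self.scale.
    attn = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=~masked,
    ).transpose(1, 2)

    output[start:end] = attn[0]

print(output.shape)  # torch.Size([15, 4, 32])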