From 37af1ddc1ff26978f6638bdd7934e97680dc94fd Mon Sep 17 00:00:00 2001
From: handsomezhuzhu <2658601135@qq.com>
Date: Tue, 10 Feb 2026 20:42:47 +0800
Subject: [PATCH] fix: use npu_fusion_attention loop (BSND) for
 prefill_no_cache to fix crash

---
 vllm_npu/attention/attention_v1.py | 58 +++++++++++++++++-------------
 1 file changed, 34 insertions(+), 24 deletions(-)

diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index 10029a3..3669db6 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -449,33 +449,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Prefill attention without KV cache (self-attention) using _npu_flash_attention."""
+        """Prefill attention without KV cache (self-attention) via per-request loop."""
         import torch_npu  # noqa: F401
 
-        # Huawei uses _npu_flash_attention for prefill
-        # Ensure contiguous inputs
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-
-        # mask needs to be contiguous and cast to expected format if needed
-        # but _npu_flash_attention handles generic mask?
-        # Huawei code: mask = attn_metadata.attn_mask...
-        # We'll pass it as is, assuming AscendMetadataBuilder created it correctly.
-
-        torch_npu._npu_flash_attention(
-            query=query,
-            key=key,
-            value=value,
-            mask=attn_metadata.attn_mask,
-            seq_len=attn_metadata.seq_lens,
-            scale_value=self.scale,
-            num_heads=self.num_heads,
-            num_kv_heads=self.num_kv_heads,
-            out=output
-        )
+        query_start_loc = attn_metadata.query_start_loc
+        seq_lens = attn_metadata.seq_lens
+        num_reqs = len(seq_lens)
 
-        return output[:num_tokens]
+        # Iterate and process each request independently to bypass TND issues
+        for i in range(num_reqs):
+            start = query_start_loc[i].item()
+            end = query_start_loc[i + 1].item()
+            q_len = end - start
+
+            # Slice this request's q, k, v and add a batch dim (BSND)
+            q = query[start:end].unsqueeze(0)
+            k = key[start:end].unsqueeze(0)
+            v = value[start:end].unsqueeze(0)
+
+            # Causal mask: True on the strict upper triangle marks masked positions
+            attn_mask = torch.ones(
+                q_len, q_len, dtype=torch.bool, device=query.device
+            ).triu_(diagonal=1).unsqueeze(0)
+
+            # Run npu_fusion_attention (BSND)
+            attn_out = torch_npu.npu_fusion_attention(
+                q, k, v,
+                head_num=self.num_heads,
+                input_layout="BSND",
+                scale=self.scale,
+                atten_mask=attn_mask,
+                pre_tockens=2147483647,
+                next_tockens=0,
+            )
+
+            output[start:end] = attn_out[0]
+
+        return output
 
     # -----------------------------------------------------------------
     # Chunked prefill — mixed prefill+decode via npu_fusion_attention
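
Note (editor's sketch, not part of the patch): the loop above runs plain causal
self-attention per request in BSND layout. As an off-device cross-check of what
each npu_fusion_attention call is expected to produce, a rough CPU reference is
sketched below. The helper name prefill_no_cache_reference and the assumed
shapes -- query/key/value as (num_tokens, num_heads, head_dim), with
query_start_loc holding cumulative query lengths -- are illustrative only; the
sketch also assumes num_kv_heads == num_heads (no GQA expansion) and a PyTorch
version whose scaled_dot_product_attention accepts the scale keyword.

    import torch
    import torch.nn.functional as F

    def prefill_no_cache_reference(query, key, value, query_start_loc, scale):
        """Per-request causal attention; a CPU stand-in for the BSND loop."""
        output = torch.empty_like(query)
        num_reqs = query_start_loc.numel() - 1
        for i in range(num_reqs):
            start = int(query_start_loc[i])
            end = int(query_start_loc[i + 1])
            # (S, N, D) -> (1, N, S, D), the layout scaled_dot_product_attention expects
            q = query[start:end].permute(1, 0, 2).unsqueeze(0)
            k = key[start:end].permute(1, 0, 2).unsqueeze(0)
            v = value[start:end].permute(1, 0, 2).unsqueeze(0)
            # is_causal=True applies the same strict-upper-triangle masking as the
            # triu_(diagonal=1) boolean mask built in the patch
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
            # back to (S, N, D) and write into this request's slot of the flat output
            output[start:end] = out.squeeze(0).permute(1, 0, 2)
        return output

Comparing the patched path's output against this reference with a loose
half-precision tolerance is one way to validate the new prefill_no_cache code
path on a small batch of requests.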