From e22617f72e600072f8d5543cb3ffc0727641e6f4 Mon Sep 17 00:00:00 2001
From: handsomezhuzhu <2658601135@qq.com>
Date: Tue, 10 Feb 2026 22:15:26 +0800
Subject: [PATCH] feat: Add Ascend NPU attention backend for vLLM using
 FlashAttention operators.

---
 vllm_npu/attention/attention_v1.py | 250 +++++++++--------------------
 1 file changed, 73 insertions(+), 177 deletions(-)

diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index ad0b9b8..9b6a3f3 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -5,9 +5,9 @@ Implements the ``AttentionBackend``, ``AttentionMetadata``,
 ``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using
 Huawei Ascend NPU FlashAttention operators:
 
-- ``torch_npu.npu_fusion_attention`` — fused multi-head attention
+- ``torch_npu._npu_flash_attention`` — prefill attention (TND layout)
 - ``torch_npu._npu_reshape_and_cache`` — KV cache update
-- ``torch_npu.npu_incre_flash_attention`` — paged-attention decode
+- ``torch_npu._npu_paged_attention`` — paged-attention decode
 """
 
 from dataclasses import dataclass
@@ -319,30 +319,27 @@ class AscendAttentionBackendImpl(AttentionImpl):
     ):
         """Update KV cache with new key/value tensors.
 
-        Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
-        slot_mapping indices.
+        Uses ``torch_npu._npu_reshape_and_cache`` for efficient in-place
+        KV cache update, matching vllm-ascend reference.
         """
+        import torch_npu  # noqa: F401
+
         if kv_cache.numel() > 0:
             if self._key_cache is None:
                 self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping
-            key_to_cache = key[:attn_metadata.num_actual_tokens]
-            val_to_cache = value[:attn_metadata.num_actual_tokens]
-
-            # Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
-            block_size = self._key_cache.shape[1]
-            block_idx = slots // block_size
-            block_offset = slots % block_size
-            self._key_cache[block_idx, block_offset] = key_to_cache
-            self._value_cache[block_idx, block_offset] = val_to_cache
+            num_actual = attn_metadata.num_actual_tokens
+            torch_npu._npu_reshape_and_cache(
+                key=key[:num_actual],
+                value=value[:num_actual],
+                key_cache=self._key_cache,
+                value_cache=self._value_cache,
+                slot_indices=slots,
+            )
         return key, value
 
-    # -----------------------------------------------------------------
-    # Forward dispatch
-    # -----------------------------------------------------------------
-
     def forward(
         self,
         layer: nn.Module,
@@ -399,15 +396,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query, key, value, attn_metadata, output, num_tokens
             )
         else:
-            # ChunkedPrefill or PrefillCacheHit
+            # ChunkedPrefill — use npu_fused_infer_attention_score
            output = self._forward_chunked_prefill(
-                query, key, value, attn_metadata, output, num_tokens
+                query, attn_metadata, output, num_tokens
            )
 
         return output
 
     # -----------------------------------------------------------------
-    # Decode path — paged attention (matches Huawei _npu_paged_attention)
+    # Decode path — paged attention via _npu_paged_attention
     # -----------------------------------------------------------------
 
     def _forward_decode(
@@ -417,29 +414,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Decode-only via npu_incre_flash_attention."""
+        """Decode-only via _npu_paged_attention (matches vllm-ascend)."""
         import torch_npu  # noqa: F401
 
-        q = query[:num_tokens].unsqueeze(2)  # (B, H, 1, D) for BNSD
-
-        attn_out = torch_npu.npu_incre_flash_attention(
-            q,
-            self._key_cache,
-            self._value_cache,
+        torch_npu._npu_paged_attention(
+            query=query,
+            key_cache=self._key_cache,
+            value_cache=self._value_cache,
+            num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
-            num_key_value_heads=self.num_kv_heads,
             scale_value=self.scale,
             block_table=attn_metadata.block_tables,
-            actual_seq_lengths=attn_metadata.seq_lens_list,
-            block_size=self._key_cache.shape[1],
-            input_layout="BNSD",
+            context_lens=attn_metadata.seq_lens,
+            out=output,
         )
-
-        output[:num_tokens] = attn_out.squeeze(2)
         return output
 
     # -----------------------------------------------------------------
-    # Prefill without KV cache
+    # Prefill without KV cache — _npu_flash_attention (TND layout)
     # -----------------------------------------------------------------
 
     def _forward_prefill_no_cache(
@@ -451,168 +443,72 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Prefill attention without KV cache (self-attention) via per-req loop."""
+        """Prefill attention without KV cache via _npu_flash_attention.
+
+        Uses TND layout and a pre-built causal mask from metadata.
+        This matches vllm-ascend's _forward_prefill_no_cache.
+        """
         import torch_npu  # noqa: F401
 
-        query_start_loc = attn_metadata.query_start_loc
-        seq_lens = attn_metadata.seq_lens
-        num_reqs = len(seq_lens)
+        mask = attn_metadata.attn_mask
 
-        # Iterate and process each request independently to bypass TND issues
-        for i in range(num_reqs):
-            start = query_start_loc[i].item()
-            end = query_start_loc[i + 1].item()
-            q_len = end - start
-
-            # Extract q, k, v (BSND)
-            q = query[start:end].unsqueeze(0)
-            k = key[start:end].unsqueeze(0)
-            v = value[start:end].unsqueeze(0)
-
-            # npu_fusion_attention: True = mask out (do NOT attend)
-            # Upper triangle = future tokens = should be masked out
-            attn_mask = torch.ones(
-                q_len, q_len, dtype=torch.bool, device=query.device
-            ).triu_(diagonal=1).unsqueeze(0).unsqueeze(0)
-
-            # Run npu_fusion_attention (BSND)
-            attn_out = torch_npu.npu_fusion_attention(
-                q, k, v,
-                head_num=self.num_heads,
-                input_layout="BSND",
-                scale=self.scale,
-                atten_mask=attn_mask,
-                pre_tockens=2147483647,
-                next_tockens=0,
-            )
-
-            output[start:end] = attn_out[0]
-
-        return output
+        torch_npu._npu_flash_attention(
+            query=query,
+            key=key,
+            value=value,
+            mask=mask,
+            seq_len=attn_metadata.seq_lens,
+            scale_value=self.scale,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            out=output,
+        )
+        return output[:num_tokens, :, :]
 
     # -----------------------------------------------------------------
-    # Chunked prefill — mixed prefill+decode via npu_fusion_attention
-    # (npu_fused_infer_attention_score requires 4D on older CANN)
+    # Chunked prefill — npu_fused_infer_attention_score (TND layout)
     # -----------------------------------------------------------------
 
     def _forward_chunked_prefill(
         self,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
         attn_metadata: AscendMetadata,
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Chunked prefill: decode tokens via npu_incre_flash_attention,
-        prefill tokens via npu_fusion_attention per request."""
+        """Chunked prefill / mixed prefill+decode via
+        npu_fused_infer_attention_score, matching vllm-ascend's
+        _forward_v1_style."""
         import torch_npu  # noqa: F401
 
-        query_start_loc = attn_metadata.query_start_loc
-        seq_lens = attn_metadata.seq_lens
+        assert self._key_cache is not None
+        assert attn_metadata.attn_mask is not None
 
-        # Per-request query lengths
-        query_lens = query_start_loc[1:] - query_start_loc[:-1]
+        num_block, block_size, _, _ = self._key_cache.shape
+        key = self._key_cache.view(num_block, block_size, -1)
+        value = self._value_cache.view(num_block, block_size, -1)
 
-        decode_mask = query_lens == 1
-        prefill_mask = ~decode_mask
-        num_decodes = decode_mask.sum().item()
+        # Trim query to actual tokens (npu_fused_infer_attention_score
+        # requires query.shape[0] == query_start_loc[-1])
+        actual_num_tokens = attn_metadata.query_start_loc[-1]
+        q = query[:actual_num_tokens]
 
-        # --- Decode tokens ---
-        if num_decodes > 0 and self._key_cache is not None:
-            decode_indices = torch.where(decode_mask)[0]
-            decode_query = query[query_start_loc[decode_indices]]
-            decode_block_tables = attn_metadata.block_tables[decode_indices]
-            decode_seq_lens = seq_lens[decode_indices].tolist()
-
-            decode_out = torch_npu.npu_incre_flash_attention(
-                decode_query.unsqueeze(2),
-                self._key_cache,
-                self._value_cache,
-                num_heads=self.num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                scale_value=self.scale,
-                block_table=decode_block_tables,
-                actual_seq_lengths=decode_seq_lens,
-                block_size=self._key_cache.shape[1],
-                input_layout="BNSD",
-            )
-
-            for i, idx in enumerate(decode_indices):
-                token_pos = query_start_loc[idx].item()
-                output[token_pos] = decode_out[i].squeeze(1)
-
-        # --- Prefill tokens (per-request via npu_fusion_attention) ---
-        if prefill_mask.any():
-            prefill_indices = torch.where(prefill_mask)[0]
-            for idx in prefill_indices:
-                start = query_start_loc[idx].item()
-                end = query_start_loc[idx + 1].item()
-                q_len = end - start
-                kv_len = seq_lens[idx].item()
-                q = query[start:end]
-
-                if self._key_cache is not None and kv_len > q_len:
-                    # Gather KV from paged cache
-                    block_table = attn_metadata.block_tables[idx]
-                    bs = self._key_cache.shape[1]
-                    num_blocks_needed = (kv_len + bs - 1) // bs
-                    block_ids = block_table[:num_blocks_needed]
-
-                    gathered_k = self._key_cache[block_ids].reshape(
-                        -1, self.num_kv_heads, self.head_size
-                    )[:kv_len]
-                    gathered_v = self._value_cache[block_ids].reshape(
-                        -1, self.num_kv_heads, self.head_size
-                    )[:kv_len]
-
-                    # npu_fusion_attention: True = mask out
-                    # For chunked prefill, mask future positions
-                    causal_mask = torch.ones(
-                        q_len, kv_len, dtype=torch.bool, device=query.device
-                    ).triu_(diagonal=kv_len - q_len + 1)
-                    # logic for chunked prefill mask (non-square)?
-                    # If q_len < kv_len (prefill extension), mask logic is complex.
-                    # Usually: mask[i, j] = True if j <= i + (kv_len - q_len).
-                    # tril with diagonal adjustment.
-                    # diagonal=kv_len - q_len ensures main diagonal alignment.
-
-                    attn_out = torch_npu.npu_fusion_attention(
-                        q.unsqueeze(0),
-                        gathered_k.unsqueeze(0),
-                        gathered_v.unsqueeze(0),
-                        head_num=self.num_heads,
-                        input_layout="BSND",
-                        scale=self.scale,
-                        sparse_mode=0,
-                        atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
-                        pre_tockens=kv_len,
-                        next_tockens=0,
-                    )
-                    output[start:end] = attn_out[0].squeeze(0)
-                else:
-                    # Full self-attention (no prior cache)
-                    k = key[start:end]
-                    v = value[start:end]
-
-                    # npu_fusion_attention: True = mask out
-                    causal_mask = torch.ones(
-                        q_len, q_len, dtype=torch.bool, device=query.device
-                    ).triu_(diagonal=1)
-
-                    attn_out = torch_npu.npu_fusion_attention(
-                        q.unsqueeze(0),
-                        k.unsqueeze(0),
-                        v.unsqueeze(0),
-                        head_num=self.num_heads,
-                        input_layout="BSND",
-                        scale=self.scale,
-                        sparse_mode=0,
-                        atten_mask=causal_mask.unsqueeze(0).unsqueeze(0),
-                        pre_tockens=q_len,
-                        next_tockens=0,
-                    )
-                    output[start:end] = attn_out[0].squeeze(0)
+        out, _ = torch_npu.npu_fused_infer_attention_score(
+            query=q,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask,
+            block_table=attn_metadata.block_tables,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+            sparse_mode=3,
+        )
+        output[:actual_num_tokens, :, :] = out[:actual_num_tokens, :, :]
 
         return output
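
Note for reviewers: the pure-PyTorch fallback removed from the KV-cache update above remains a useful reference for what the fused ``torch_npu._npu_reshape_and_cache`` call computes. A minimal sketch, assuming the (num_blocks, block_size, num_kv_heads, head_size) cache layout used in this file and a flat per-token slot index of block_id * block_size + offset; the helper name is illustrative, not part of the patch:

import torch


def reshape_and_cache_reference(
    key: torch.Tensor,           # (num_tokens, num_kv_heads, head_size)
    value: torch.Tensor,         # (num_tokens, num_kv_heads, head_size)
    key_cache: torch.Tensor,     # (num_blocks, block_size, num_kv_heads, head_size)
    value_cache: torch.Tensor,   # same layout as key_cache
    slot_mapping: torch.Tensor,  # (num_tokens,), slot = block_id * block_size + offset
) -> None:
    # Decompose each flat slot into (block, offset) and scatter K/V in place,
    # mirroring the indexing fallback that this patch replaces with the fused op.
    block_size = key_cache.shape[1]
    block_idx = slot_mapping // block_size
    block_offset = slot_mapping % block_size
    key_cache[block_idx, block_offset] = key
    value_cache[block_idx, block_offset] = value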