From 810a2ef7572a6cbb0d24557aa5316923469c7e6c Mon Sep 17 00:00:00 2001 From: handsomezhuzhu <2658601135@qq.com> Date: Tue, 10 Feb 2026 20:06:52 +0800 Subject: [PATCH] refactor: align attention with Huawei vllm-ascend - reshape_and_cache with kv_cache[0]/[1], _get_fia_params, npu_fused_infer_attention_score for chunked prefill, add actual_seq_lengths_q --- vllm_npu/attention/attention_v1.py | 266 ++++++++++++----------------- 1 file changed, 109 insertions(+), 157 deletions(-) diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py index da9e7b3..c993b17 100644 --- a/vllm_npu/attention/attention_v1.py +++ b/vllm_npu/attention/attention_v1.py @@ -141,6 +141,7 @@ class AscendMetadata: query_start_loc: Optional[torch.Tensor] = None # (batch+1,) query_lens: Optional[torch.Tensor] = None max_query_len: Optional[int] = None + actual_seq_lengths_q: Optional[List[int]] = None # cumulative q positions # KV cache mapping block_tables: Optional[torch.Tensor] = None # (batch, max_blocks) @@ -207,12 +208,15 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): attn_state = AscendAttentionState.ChunkedPrefill # Build cumulative sequence lengths for query (for prefill) + num_reqs = common_attn_metadata.num_reqs + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] query_start_loc = common_attn_metadata.query_start_loc.to( dtype=torch.int64 ) + actual_seq_lengths_q = query_start_loc_cpu[1:].tolist() seq_lens = common_attn_metadata.seq_lens - seq_lens_list = common_attn_metadata.seq_lens_cpu.tolist() + seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() # Build attention mask for prefill (causal mask) attn_mask = None @@ -232,6 +236,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): seq_lens_list=seq_lens_list, query_start_loc=query_start_loc, max_query_len=max_query_len, + actual_seq_lengths_q=actual_seq_lengths_q, block_tables=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, attn_mask=attn_mask, @@ -297,15 +302,72 @@ class AscendAttentionBackendImpl(AttentionImpl): self._key_cache: Optional[torch.Tensor] = None self._value_cache: Optional[torch.Tensor] = None + def reshape_and_cache( + self, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: "AscendMetadata", + ): + """Update KV cache with new key/value tensors. + + Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via + slot_mapping indices. 
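+
+        kv_cache here is the stacked tensor of shape
+        (2, num_blocks, block_size, num_kv_heads, head_size); index 0 holds
+        the key cache and index 1 the value cache.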
+ """ + if kv_cache.numel() > 0: + if self._key_cache is None: + self._key_cache, self._value_cache = kv_cache[0], kv_cache[1] + + slots = attn_metadata.slot_mapping + key_to_cache = key[:attn_metadata.num_actual_tokens] + val_to_cache = value[:attn_metadata.num_actual_tokens] + + # Use pure-PyTorch indexing (ATB reshape_and_cache may fail + # depending on environment; this is functionally identical) + block_size = self._key_cache.shape[1] + block_idx = slots // block_size + block_offset = slots % block_size + self._key_cache[block_idx, block_offset] = key_to_cache + self._value_cache[block_idx, block_offset] = val_to_cache + + return key, value + + # ----------------------------------------------------------------- + # Forward dispatch (matches Huawei vllm-ascend structure) + # ----------------------------------------------------------------- + + def _get_fia_params( + self, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: "AscendMetadata", + ): + """Prepare key, value, block_size, block_table and kv_seq_lens + for npu_fused_infer_attention_score, following Huawei's approach.""" + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + block_size = 128 + block_table = None + actual_seq_lengths_kv = attn_metadata.query_start_loc[1:].tolist() + else: + # DecodeOnly / PrefillCacheHit / ChunkedPrefill — read from cache + num_block, block_size, _, _ = self._key_cache.shape + key = self._key_cache.view(num_block, block_size, -1) + value = self._value_cache.view(num_block, block_size, -1) + block_table = attn_metadata.block_tables + actual_seq_lengths_kv = attn_metadata.seq_lens_list + return key, value, block_size, block_table, actual_seq_lengths_kv + def forward( self, layer: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor, ...], + kv_cache: torch.Tensor, attn_metadata: AscendMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Ascend attention. @@ -313,8 +375,8 @@ class AscendAttentionBackendImpl(AttentionImpl): query: (num_tokens, num_heads * head_size) key: (num_tokens, num_kv_heads * head_size) value: (num_tokens, num_kv_heads * head_size) - kv_cache: (key_cache, value_cache) each - (num_blocks, block_size, num_kv_heads, head_size) + kv_cache: tensor of shape + (2, num_blocks, block_size, num_kv_heads, head_size) attn_metadata: AscendMetadata for this forward call. Returns: @@ -322,48 +384,24 @@ class AscendAttentionBackendImpl(AttentionImpl): """ import torch_npu # noqa: F401 + assert output is not None, "Output tensor must be provided." 
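+        # The caller is expected to preallocate `output` (hence the assert
+        # above); the old in-method allocation fallback has been removed.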
num_tokens = query.shape[0] - if output is None: - output = torch.empty( - num_tokens, - self.num_heads, - self.head_size, - dtype=query.dtype, - device=query.device, - ) - if attn_metadata is None: - return output.view(num_tokens, self.hidden_size).fill_(0) + return output.fill_(0) - num_actual_tokens = attn_metadata.num_actual_tokens - - # Reshape Q/K/V to BSH (tokens, heads, head_dim) + # Reshape Q/K/V to TND (tokens, heads, head_dim) query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size).contiguous() + value = value.view(-1, self.num_kv_heads, self.head_size) - # ---------------------------------------------------------- # Step 1: Update KV cache - # ---------------------------------------------------------- - if kv_cache is not None and len(kv_cache.shape) > 1: - if self._key_cache is None: - self._key_cache, self._value_cache = kv_cache.unbind(0) + if key is not None and value is not None: + key, value = self.reshape_and_cache( + key, value, kv_cache, attn_metadata + ) - slots = attn_metadata.slot_mapping - # Pure PyTorch reshape_and_cache (avoids ATB dependency) - key_to_cache = key[:num_actual_tokens] - val_to_cache = value[:num_actual_tokens] - block_size = self._key_cache.shape[1] - block_idx = slots // block_size - block_offset = slots % block_size - self._key_cache[block_idx, block_offset] = key_to_cache - self._value_cache[block_idx, block_offset] = val_to_cache - - - # ---------------------------------------------------------- # Step 2: Compute attention - # ---------------------------------------------------------- if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: output = self._forward_decode( query, attn_metadata, output, num_tokens @@ -373,15 +411,15 @@ class AscendAttentionBackendImpl(AttentionImpl): query, key, value, attn_metadata, output, num_tokens ) else: - # ChunkedPrefill or PrefillCacheHit - output = self._forward_chunked_prefill( - query, key, value, attn_metadata, output, num_tokens + # ChunkedPrefill or PrefillCacheHit — use FIA with block tables + output = self._forward_fused_infer_attention( + query, key, value, attn_metadata, output ) - return output.view(num_tokens, self.hidden_size) + return output # ----------------------------------------------------------------- - # Decode path — paged attention via npu_incre_flash_attention + # Decode path — paged attention (matches Huawei _npu_paged_attention) # ----------------------------------------------------------------- def _forward_decode( @@ -391,13 +429,9 @@ class AscendAttentionBackendImpl(AttentionImpl): output: torch.Tensor, num_tokens: int, ) -> torch.Tensor: - """Decode-only attention using incremental flash attention.""" + """Decode-only via npu_incre_flash_attention.""" import torch_npu # noqa: F401 - # npu_incre_flash_attention expects: - # query: (batch, 1, num_heads, head_size) - # key_cache: (num_blocks, block_size, num_kv_heads, head_size) - # value_cache: (num_blocks, block_size, num_kv_heads, head_size) q = query[:num_tokens].unsqueeze(1) # (B, 1, H, D) attn_out = torch_npu.npu_incre_flash_attention( @@ -417,7 +451,7 @@ class AscendAttentionBackendImpl(AttentionImpl): return output # ----------------------------------------------------------------- - # Prefill without KV cache (first token, no paging) + # Prefill without KV cache # ----------------------------------------------------------------- def _forward_prefill_no_cache( @@ -453,127 +487,45 @@ class 
AscendAttentionBackendImpl(AttentionImpl): return output # ----------------------------------------------------------------- - # Chunked prefill — mixed prefill+decode + # Fused Infer Attention (prefill with cache / chunked prefill) + # Matches Huawei's forward_fused_infer_attention approach # ----------------------------------------------------------------- - def _forward_chunked_prefill( + def _forward_fused_infer_attention( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor, - num_tokens: int, ) -> torch.Tensor: - """Chunked prefill using npu_fusion_attention with paged KV cache.""" + """Use npu_fused_infer_attention_score with TND layout and block + tables — the same approach Huawei uses for chunked prefill and + cache-hit prefill.""" import torch_npu # noqa: F401 - # Split batch into decodes and prefills based on query length - query_start_loc = attn_metadata.query_start_loc - seq_lens = attn_metadata.seq_lens + key, value, block_size, block_table, actual_seq_lengths_kv = ( + self._get_fia_params(key, value, attn_metadata) + ) + num_tokens = attn_metadata.actual_seq_lengths_q[-1] + query = query[:num_tokens] - # Compute per-request query lengths - query_lens = query_start_loc[1:] - query_start_loc[:-1] - num_requests = len(query_lens) - - # Separate decode (query_len == 1) and prefill requests - decode_mask = query_lens == 1 - prefill_mask = ~decode_mask - num_decodes = decode_mask.sum().item() - - # Process decode tokens - if num_decodes > 0 and self._key_cache is not None: - decode_indices = torch.where(decode_mask)[0] - decode_query = query[query_start_loc[decode_indices]] - decode_block_tables = attn_metadata.block_tables[decode_indices] - decode_seq_lens = seq_lens[decode_indices].tolist() - - decode_q = decode_query.unsqueeze(1) # (B_decode, 1, H, D) - - decode_out = torch_npu.npu_incre_flash_attention( - decode_q, - self._key_cache, - self._value_cache, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - scale_value=self.scale, - block_table=decode_block_tables, - actual_seq_lengths=decode_seq_lens, - block_size=self._key_cache.shape[1], - input_layout="BNSD", - ) - - for i, idx in enumerate(decode_indices): - token_pos = query_start_loc[idx].item() - output[token_pos] = decode_out[i].squeeze(0) - - # Process prefill tokens - if prefill_mask.any(): - prefill_indices = torch.where(prefill_mask)[0] - for idx in prefill_indices: - start = query_start_loc[idx].item() - end = query_start_loc[idx + 1].item() - q_len = end - start - kv_len = seq_lens[idx].item() - - q = query[start:end] # (q_len, H, D) - - # Use npu_fusion_attention for this single prefill request - # Build a causal mask for this sequence - causal_mask = torch.ones( - kv_len, kv_len, dtype=torch.bool, device=query.device - ).triu_(diagonal=1) - - # For chunked prefill, key/value come from the cache - if self._key_cache is not None and kv_len > q_len: - # Gather KV from paged cache for this request - block_table = attn_metadata.block_tables[idx] - num_blocks_needed = (kv_len + self._key_cache.shape[1] - 1) \ - // self._key_cache.shape[1] - block_ids = block_table[:num_blocks_needed] - - # Gather KV from block cache - gathered_k = self._key_cache[block_ids].reshape( - -1, self.num_kv_heads, self.head_size - )[:kv_len] - gathered_v = self._value_cache[block_ids].reshape( - -1, self.num_kv_heads, self.head_size - )[:kv_len] - - # Only last q_len rows of the mask - causal_mask = causal_mask[kv_len - q_len : kv_len, :kv_len] - - attn_out 
= torch_npu.npu_fusion_attention( - q.unsqueeze(0), # (1, q_len, H, D) — BSH layout - gathered_k.unsqueeze(0), - gathered_v.unsqueeze(0), - head_num=self.num_heads, - input_layout="BSND", - scale=self.scale, - sparse_mode=0, - atten_mask=causal_mask.unsqueeze(0), - pre_tockens=kv_len, - next_tockens=0, - ) - output[start:end] = attn_out[0].squeeze(0) - else: - # Full self-attention (no prior cache) - k = key[start:end] - v = value[start:end] - causal_mask = causal_mask[:q_len, :q_len] - - attn_out = torch_npu.npu_fusion_attention( - q.unsqueeze(0), - k.unsqueeze(0), - v.unsqueeze(0), - head_num=self.num_heads, - input_layout="BSND", - scale=self.scale, - sparse_mode=0, - atten_mask=causal_mask.unsqueeze(0), - pre_tockens=q_len, - next_tockens=0, - ) - output[start:end] = attn_out[0].squeeze(0) + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key, + value=value, + atten_mask=attn_metadata.attn_mask, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=attn_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=actual_seq_lengths_kv, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + ) + attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size) + output[:num_tokens] = attn_output[:num_tokens] return output
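-- 
Reviewer note: the snippet below is a minimal, CPU-only sketch of the metadata
contract the new cache-write and fused-infer-attention paths rely on. It uses
toy sizes and hand-picked block assignments; the field names (slot_mapping,
query_start_loc, block_tables, actual_seq_lengths_q) mirror AscendMetadata,
while everything else is illustrative and not part of vllm_npu. It checks two
things: that the pure-PyTorch indexing in reshape_and_cache
(slot // block_size, slot % block_size) writes K/V exactly where a block-table
lookup later finds them, and that actual_seq_lengths_q is simply
query_start_loc[1:], the cumulative per-request query lengths expected by the
TND layout.

    import torch

    # Toy sizes; real values come from the model config and the block manager.
    num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16

    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    value_cache = torch.zeros_like(key_cache)

    # Two requests: request 0 contributes 3 new tokens, request 1 contributes 2.
    query_lens = torch.tensor([3, 2])
    query_start_loc = torch.cat(
        [torch.zeros(1, dtype=torch.long), query_lens.cumsum(0)]
    )                                                    # [0, 3, 5]
    actual_seq_lengths_q = query_start_loc[1:].tolist()  # [3, 5]

    num_tokens = int(query_start_loc[-1])
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)

    # Flat slot ids assigned by the block manager: request 0 writes into
    # block 2, request 1 into block 5, one slot per new token.
    slot_mapping = torch.tensor(
        [2 * block_size, 2 * block_size + 1, 2 * block_size + 2,
         5 * block_size, 5 * block_size + 1]
    )

    # Same indexing as the patched reshape_and_cache.
    block_idx = slot_mapping // block_size
    block_offset = slot_mapping % block_size
    key_cache[block_idx, block_offset] = key
    value_cache[block_idx, block_offset] = value

    # Reading the cache back through a per-request block table recovers the
    # original keys, which is what the fused-infer-attention path does via
    # block_table and actual_seq_lengths_kv.
    block_tables = torch.tensor([[2], [5]])
    for req in range(2):
        start, end = int(query_start_loc[req]), int(query_start_loc[req + 1])
        gathered = key_cache[block_tables[req]].reshape(
            -1, num_kv_heads, head_size
        )[: end - start]
        assert torch.equal(gathered, key[start:end])

    print("actual_seq_lengths_q:", actual_seq_lengths_q)

On an NPU build, the same tensors would feed
torch_npu.npu_fused_infer_attention_score through block_table,
actual_seq_lengths and actual_seq_lengths_kv, as in
_forward_fused_infer_attention above.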