""" Ascend NPU attention backend for vLLM v1. Implements the ``AttentionBackend``, ``AttentionMetadata``, ``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using Huawei Ascend NPU FlashAttention operators: - ``torch_npu._npu_flash_attention`` — prefill attention (TND layout) - ``torch_npu._npu_reshape_and_cache`` — KV cache update - ``torch_npu._npu_paged_attention`` — paged-attention decode """ from dataclasses import dataclass from enum import IntEnum from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple import torch import torch.nn as nn from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, AttentionType, ) from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, ) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) # ===================================================================== # Attention state enum # ===================================================================== class AscendAttentionState(IntEnum): """Attention computation state, determines the kernel path.""" PrefillNoCache = 0 PrefillCacheHit = 1 DecodeOnly = 2 ChunkedPrefill = 3 # ===================================================================== # Backend class # ===================================================================== class AscendAttentionBackend(AttentionBackend): """Ascend NPU FlashAttention backend.""" accept_output_buffer: bool = True @staticmethod def get_name() -> str: return "ASCEND_ATTN" @staticmethod def get_impl_cls() -> type["AttentionImpl"]: return AscendAttentionBackendImpl @staticmethod def get_metadata_cls() -> type["AscendMetadata"]: return AscendMetadata @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: return AscendAttentionMetadataBuilder @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, **kwargs, ) -> Tuple[int, ...]: """KV cache shape: (2, num_blocks, block_size, num_kv_heads, head_size). The leading ``2`` stores key and value caches in a single tensor. They are split via ``kv_cache.unbind(0)`` at runtime. """ return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( src_kv_cache: List[torch.Tensor], dst_kv_cache: List[torch.Tensor], src_to_dst: torch.Tensor, ) -> None: """Swap KV cache blocks between src and dst.""" src_key_cache, src_value_cache = src_kv_cache dst_key_cache, dst_value_cache = dst_kv_cache for src_idx, dst_idx in src_to_dst.tolist(): dst_key_cache[dst_idx].copy_(src_key_cache[src_idx]) dst_value_cache[dst_idx].copy_(src_value_cache[src_idx]) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dsts: torch.Tensor, ) -> None: """Copy KV cache blocks in-place.""" key_caches = [kv[0] for kv in kv_caches] value_caches = [kv[1] for kv in kv_caches] for src_idx, dst_idx in src_to_dsts.tolist(): for key_cache in key_caches: key_cache[dst_idx].copy_(key_cache[src_idx]) for value_cache in value_caches: value_cache[dst_idx].copy_(value_cache[src_idx]) # ===================================================================== # Metadata dataclass # ===================================================================== @dataclass class AscendMetadata: """Per-layer attention metadata for the Ascend backend.""" attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill num_actual_tokens: int = 0 # Sequence lengths and query positions seq_lens: Optional[torch.Tensor] = None # (batch,) seq_lens_list: Optional[List[int]] = None query_start_loc: Optional[torch.Tensor] = None # (batch+1,) query_lens: Optional[torch.Tensor] = None max_query_len: Optional[int] = None actual_seq_lengths_q: Optional[List[int]] = None # cumulative q positions # KV cache mapping block_tables: Optional[torch.Tensor] = None # (batch, max_blocks) slot_mapping: Optional[torch.Tensor] = None # (num_tokens,) # Attention mask (for prefill causal masking) attn_mask: Optional[torch.Tensor] = None # ===================================================================== # Metadata builder # ===================================================================== class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): """Builds ``AscendMetadata`` from ``CommonAttentionMetadata``.""" cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE ) reorder_batch_threshold: ClassVar[int] = 1 def __init__( self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: "VllmConfig", device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size self.num_kv_heads = kv_cache_spec.num_kv_heads self.head_size = kv_cache_spec.head_size def reorder_batch( self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput", ) -> bool: """ Reorder so decodes (query_len == 1) come first, prefills after. """ from vllm.v1.attention.backends.utils import ( reorder_batch_to_split_decodes_and_prefills, ) return reorder_batch_to_split_decodes_and_prefills( input_batch, scheduler_output, decode_threshold=1 ) def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> AscendMetadata: """Build AscendMetadata from the common attention metadata.""" num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len # Determine attention state num_reqs = common_attn_metadata.num_reqs if max_query_len == 1: attn_state = AscendAttentionState.DecodeOnly else: # Check if this is a pure prefill (no prior cache) or chunked query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs] # PrefillNoCache: all requests have query_len == seq_len if (query_lens_cpu == seq_lens_cpu).all(): attn_state = AscendAttentionState.PrefillNoCache else: attn_state = AscendAttentionState.ChunkedPrefill # Build cumulative sequence lengths for query (for prefill) query_start_loc_cpu_full = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] query_start_loc = common_attn_metadata.query_start_loc.to( dtype=torch.int64 ) actual_seq_lengths_q = query_start_loc_cpu_full[1:].tolist() seq_lens = common_attn_metadata.seq_lens seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() # Build attention mask for prefill (causal mask) attn_mask = None if attn_state != AscendAttentionState.DecodeOnly: max_seq = common_attn_metadata.max_seq_len attn_mask = torch.ones( max_seq, max_seq, dtype=torch.bool, device=self.device, ).triu_(diagonal=1) return AscendMetadata( attn_state=attn_state, num_actual_tokens=num_actual_tokens, seq_lens=seq_lens, seq_lens_list=seq_lens_list, query_start_loc=query_start_loc, max_query_len=max_query_len, actual_seq_lengths_q=actual_seq_lengths_q, block_tables=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, attn_mask=attn_mask, ) def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata, ) -> AscendMetadata: """Build metadata for graph capture (decode-only).""" return self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, ) # ===================================================================== # Attention implementation # ===================================================================== class AscendAttentionBackendImpl(AttentionImpl): """ Ascend NPU attention kernel implementation. Uses ``torch_npu.npu_fusion_attention`` for prefill and ``torch_npu.npu_incre_flash_attention`` for decode. """ def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.hidden_size = self.num_heads * self.head_size self.kv_cache_dtype = kv_cache_dtype self.sliding_window = sliding_window self.attn_type = attn_type if alibi_slopes is not None: alibi_slopes = torch.tensor( alibi_slopes, dtype=torch.float32, device="npu" ) self.alibi_slopes = alibi_slopes assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads # Cached references to the KV cache tensors self._key_cache: Optional[torch.Tensor] = None self._value_cache: Optional[torch.Tensor] = None def reshape_and_cache( self, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: "AscendMetadata", ): """Update KV cache with new key/value tensors. Uses ``torch_npu._npu_reshape_and_cache`` for efficient in-place KV cache update, matching vllm-ascend reference. """ import torch_npu # noqa: F401 if kv_cache.numel() > 0: if self._key_cache is None: self._key_cache, self._value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping num_actual = attn_metadata.num_actual_tokens torch_npu._npu_reshape_and_cache( key=key[:num_actual], value=value[:num_actual], key_cache=self._key_cache, value_cache=self._value_cache, slot_indices=slots, ) return key, value def forward( self, layer: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AscendMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Ascend attention. Args: query: (num_tokens, num_heads * head_size) key: (num_tokens, num_kv_heads * head_size) value: (num_tokens, num_kv_heads * head_size) kv_cache: tensor of shape (2, num_blocks, block_size, num_kv_heads, head_size) attn_metadata: AscendMetadata for this forward call. Returns: (num_tokens, num_heads * head_size) """ import torch_npu # noqa: F401 assert output is not None, "Output tensor must be provided." num_tokens = query.shape[0] if attn_metadata is None: return output.fill_(0) # Reshape Q/K/V to TND (tokens, heads, head_dim) query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) # TODO: Remove this contiguous in the future. value = value.contiguous() # Step 1: Update KV cache if key is not None and value is not None: key, value = self.reshape_and_cache( key, value, kv_cache, attn_metadata ) # Step 2: Compute attention if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: output = self._forward_decode( query, attn_metadata, output, num_tokens ) elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: output = self._forward_prefill_no_cache( query, key, value, attn_metadata, output, num_tokens ) else: # ChunkedPrefill — use npu_fused_infer_attention_score output = self._forward_chunked_prefill( query, attn_metadata, output, num_tokens ) return output # ----------------------------------------------------------------- # Decode path — paged attention via _npu_paged_attention # ----------------------------------------------------------------- def _forward_decode( self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor, num_tokens: int, ) -> torch.Tensor: """Decode-only via _npu_paged_attention (matches vllm-ascend).""" import torch_npu # noqa: F401 torch_npu._npu_paged_attention( query=query, key_cache=self._key_cache, value_cache=self._value_cache, num_kv_heads=self.num_kv_heads, num_heads=self.num_heads, scale_value=self.scale, block_table=attn_metadata.block_tables, context_lens=attn_metadata.seq_lens, out=output, ) return output # ----------------------------------------------------------------- # Prefill without KV cache — _npu_flash_attention (TND layout) # ----------------------------------------------------------------- def _forward_prefill_no_cache( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor, num_tokens: int, ) -> torch.Tensor: """Prefill attention without KV cache via _npu_flash_attention. Uses TND layout and a pre-built causal mask from metadata. This matches vllm-ascend's _forward_prefill_no_cache. """ import torch_npu # noqa: F401 mask = attn_metadata.attn_mask torch_npu._npu_flash_attention( query=query, key=key, value=value, mask=mask, seq_len=attn_metadata.seq_lens, scale_value=self.scale, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, out=output, ) return output[:num_tokens, :, :] # ----------------------------------------------------------------- # Chunked prefill — npu_fused_infer_attention_score (TND layout) # ----------------------------------------------------------------- def _forward_chunked_prefill( self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor, num_tokens: int, ) -> torch.Tensor: """Chunked prefill / mixed prefill+decode via npu_fused_infer_attention_score, matching vllm-ascend's _forward_v1_style.""" import torch_npu # noqa: F401 assert self._key_cache is not None assert attn_metadata.attn_mask is not None num_block, block_size, _, _ = self._key_cache.shape key = self._key_cache.view(num_block, block_size, -1) value = self._value_cache.view(num_block, block_size, -1) # Trim query to actual tokens (npu_fused_infer_attention_score # requires query.shape[0] == query_start_loc[-1]) actual_num_tokens = attn_metadata.query_start_loc[-1] q = query[:actual_num_tokens] out, _ = torch_npu.npu_fused_infer_attention_score( query=q, key=key, value=value, atten_mask=attn_metadata.attn_mask, block_table=attn_metadata.block_tables, input_layout="TND", block_size=block_size, actual_seq_lengths=attn_metadata.actual_seq_lengths_q, actual_seq_lengths_kv=attn_metadata.seq_lens_list, num_key_value_heads=self.num_kv_heads, num_heads=self.num_heads, scale=self.scale, sparse_mode=3, ) output[:actual_num_tokens, :, :] = out[:actual_num_tokens, :, :] return output