""" Ascend NPU attention backend for vLLM v1. Implements the ``AttentionBackend``, ``AttentionMetadata``, ``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using Huawei Ascend NPU FlashAttention operators: - ``torch_npu.npu_fusion_attention`` — fused multi-head attention - ``torch_npu._npu_reshape_and_cache`` — KV cache update - ``torch_npu.npu_incre_flash_attention`` — paged-attention decode """ from dataclasses import dataclass from enum import IntEnum from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple import torch import torch.nn as nn from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, AttentionType, ) from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, ) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) # ===================================================================== # Attention state enum # ===================================================================== class AscendAttentionState(IntEnum): """Attention computation state, determines the kernel path.""" PrefillNoCache = 0 PrefillCacheHit = 1 DecodeOnly = 2 ChunkedPrefill = 3 # ===================================================================== # Backend class # ===================================================================== class AscendAttentionBackend(AttentionBackend): """Ascend NPU FlashAttention backend.""" accept_output_buffer: bool = True @staticmethod def get_name() -> str: return "ASCEND_ATTN" @staticmethod def get_impl_cls() -> type["AttentionImpl"]: return AscendAttentionBackendImpl @staticmethod def get_metadata_cls() -> type["AscendMetadata"]: return AscendMetadata @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: return AscendAttentionMetadataBuilder @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, **kwargs, ) -> Tuple[int, ...]: """KV cache shape: (2, num_blocks, block_size, num_kv_heads, head_size). The leading ``2`` stores key and value caches in a single tensor. They are split via ``kv_cache.unbind(0)`` at runtime. 
""" return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( src_kv_cache: List[torch.Tensor], dst_kv_cache: List[torch.Tensor], src_to_dst: torch.Tensor, ) -> None: """Swap KV cache blocks between src and dst.""" src_key_cache, src_value_cache = src_kv_cache dst_key_cache, dst_value_cache = dst_kv_cache for src_idx, dst_idx in src_to_dst.tolist(): dst_key_cache[dst_idx].copy_(src_key_cache[src_idx]) dst_value_cache[dst_idx].copy_(src_value_cache[src_idx]) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dsts: torch.Tensor, ) -> None: """Copy KV cache blocks in-place.""" key_caches = [kv[0] for kv in kv_caches] value_caches = [kv[1] for kv in kv_caches] for src_idx, dst_idx in src_to_dsts.tolist(): for key_cache in key_caches: key_cache[dst_idx].copy_(key_cache[src_idx]) for value_cache in value_caches: value_cache[dst_idx].copy_(value_cache[src_idx]) # ===================================================================== # Metadata dataclass # ===================================================================== @dataclass class AscendMetadata: """Per-layer attention metadata for the Ascend backend.""" attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill num_actual_tokens: int = 0 # Sequence lengths and query positions seq_lens: Optional[torch.Tensor] = None # (batch,) seq_lens_list: Optional[List[int]] = None query_start_loc: Optional[torch.Tensor] = None # (batch+1,) query_lens: Optional[torch.Tensor] = None max_query_len: Optional[int] = None actual_seq_lengths_q: Optional[List[int]] = None # cumulative q positions # KV cache mapping block_tables: Optional[torch.Tensor] = None # (batch, max_blocks) slot_mapping: Optional[torch.Tensor] = None # (num_tokens,) # Attention mask (for prefill causal masking) attn_mask: Optional[torch.Tensor] = None # ===================================================================== # Metadata builder # ===================================================================== class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): """Builds ``AscendMetadata`` from ``CommonAttentionMetadata``.""" cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE ) reorder_batch_threshold: ClassVar[int] = 1 def __init__( self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: "VllmConfig", device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size self.num_kv_heads = kv_cache_spec.num_kv_heads self.head_size = kv_cache_spec.head_size def reorder_batch( self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput", ) -> bool: """ Reorder so decodes (query_len == 1) come first, prefills after. 
""" from vllm.v1.attention.backends.utils import ( reorder_batch_to_split_decodes_and_prefills, ) return reorder_batch_to_split_decodes_and_prefills( input_batch, scheduler_output, decode_threshold=1 ) def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> AscendMetadata: """Build AscendMetadata from the common attention metadata.""" num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len # Determine attention state num_reqs = common_attn_metadata.num_reqs if max_query_len == 1: attn_state = AscendAttentionState.DecodeOnly else: # Check if this is a pure prefill (no prior cache) or chunked query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs] # PrefillNoCache: all requests have query_len == seq_len if (query_lens_cpu == seq_lens_cpu).all(): attn_state = AscendAttentionState.PrefillNoCache else: attn_state = AscendAttentionState.ChunkedPrefill # Build cumulative sequence lengths for query (for prefill) query_start_loc_cpu_full = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] query_start_loc = common_attn_metadata.query_start_loc.to( dtype=torch.int64 ) actual_seq_lengths_q = query_start_loc_cpu_full[1:].tolist() seq_lens = common_attn_metadata.seq_lens seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() # Build attention mask for prefill (causal mask) attn_mask = None if attn_state != AscendAttentionState.DecodeOnly: max_seq = common_attn_metadata.max_seq_len attn_mask = torch.ones( max_seq, max_seq, dtype=torch.bool, device=self.device, ).triu_(diagonal=1) return AscendMetadata( attn_state=attn_state, num_actual_tokens=num_actual_tokens, seq_lens=seq_lens, seq_lens_list=seq_lens_list, query_start_loc=query_start_loc, max_query_len=max_query_len, actual_seq_lengths_q=actual_seq_lengths_q, block_tables=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, attn_mask=attn_mask, ) def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata, ) -> AscendMetadata: """Build metadata for graph capture (decode-only).""" return self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, ) # ===================================================================== # Attention implementation # ===================================================================== class AscendAttentionBackendImpl(AttentionImpl): """ Ascend NPU attention kernel implementation. Uses ``torch_npu.npu_fusion_attention`` for prefill and ``torch_npu.npu_incre_flash_attention`` for decode. 
""" def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.hidden_size = self.num_heads * self.head_size self.kv_cache_dtype = kv_cache_dtype self.sliding_window = sliding_window self.attn_type = attn_type if alibi_slopes is not None: alibi_slopes = torch.tensor( alibi_slopes, dtype=torch.float32, device="npu" ) self.alibi_slopes = alibi_slopes assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads # Cached references to the KV cache tensors self._key_cache: Optional[torch.Tensor] = None self._value_cache: Optional[torch.Tensor] = None def reshape_and_cache( self, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: "AscendMetadata", ): """Update KV cache with new key/value tensors. Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via slot_mapping indices. """ if kv_cache.numel() > 0: if self._key_cache is None: self._key_cache, self._value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping key_to_cache = key[:attn_metadata.num_actual_tokens] val_to_cache = value[:attn_metadata.num_actual_tokens] # Use pure-PyTorch indexing (ATB reshape_and_cache may fail # depending on environment; this is functionally identical) block_size = self._key_cache.shape[1] block_idx = slots // block_size block_offset = slots % block_size self._key_cache[block_idx, block_offset] = key_to_cache self._value_cache[block_idx, block_offset] = val_to_cache return key, value # ----------------------------------------------------------------- # Forward dispatch # ----------------------------------------------------------------- def forward( self, layer: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AscendMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Ascend attention. Args: query: (num_tokens, num_heads * head_size) key: (num_tokens, num_kv_heads * head_size) value: (num_tokens, num_kv_heads * head_size) kv_cache: tensor of shape (2, num_blocks, block_size, num_kv_heads, head_size) attn_metadata: AscendMetadata for this forward call. Returns: (num_tokens, num_heads * head_size) """ import torch_npu # noqa: F401 assert output is not None, "Output tensor must be provided." 

    def reshape_and_cache(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "AscendMetadata",
    ):
        """Update KV cache with new key/value tensors.

        Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
        slot_mapping indices.
        """
        if kv_cache.numel() > 0:
            if self._key_cache is None:
                self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping
            key_to_cache = key[:attn_metadata.num_actual_tokens]
            val_to_cache = value[:attn_metadata.num_actual_tokens]
            # Use pure-PyTorch indexing (ATB reshape_and_cache may fail
            # depending on environment; this is functionally identical)
            block_size = self._key_cache.shape[1]
            block_idx = slots // block_size
            block_offset = slots % block_size
            self._key_cache[block_idx, block_offset] = key_to_cache
            self._value_cache[block_idx, block_offset] = val_to_cache
        return key, value

    # -----------------------------------------------------------------
    # Forward dispatch
    # -----------------------------------------------------------------
    def forward(
        self,
        layer: nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            query: (num_tokens, num_heads * head_size)
            key: (num_tokens, num_kv_heads * head_size)
            value: (num_tokens, num_kv_heads * head_size)
            kv_cache: tensor of shape
                (2, num_blocks, block_size, num_kv_heads, head_size)
            attn_metadata: AscendMetadata for this forward call.

        Returns:
            (num_tokens, num_heads * head_size)
        """
        import torch_npu  # noqa: F401

        assert output is not None, "Output tensor must be provided."

        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)

        # Reshape Q/K/V to TND (tokens, heads, head_dim)
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Step 1: Update KV cache
        if key is not None and value is not None:
            key, value = self.reshape_and_cache(
                key, value, kv_cache, attn_metadata
            )

        # Step 2: Compute attention
        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            output = self._forward_decode(
                query, attn_metadata, output, num_tokens
            )
        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            output = self._forward_prefill_no_cache(
                query, key, value, attn_metadata, output, num_tokens
            )
        else:  # ChunkedPrefill or PrefillCacheHit
            output = self._forward_chunked_prefill(
                query, key, value, attn_metadata, output, num_tokens
            )
        return output

    # -----------------------------------------------------------------
    # Decode path — paged attention (matches Huawei _npu_paged_attention)
    # -----------------------------------------------------------------
    def _forward_decode(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Decode-only via npu_incre_flash_attention."""
        import torch_npu  # noqa: F401

        q = query[:num_tokens].unsqueeze(1)  # (B, 1, H, D)
        attn_out = torch_npu.npu_incre_flash_attention(
            q,
            self._key_cache,
            self._value_cache,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths=attn_metadata.seq_lens_list,
            block_size=self._key_cache.shape[1],
            input_layout="BNSD",
        )
        output[:num_tokens] = attn_out.squeeze(1)
        return output

    # -----------------------------------------------------------------
    # Prefill without KV cache
    # -----------------------------------------------------------------
    def _forward_prefill_no_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Prefill attention without KV cache (self-attention)."""
        import torch_npu  # noqa: F401

        cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
        attn_out = torch_npu.npu_fusion_attention(
            query[:num_tokens],
            key[:num_tokens],
            value[:num_tokens],
            head_num=self.num_heads,
            input_layout="TND",
            scale=self.scale,
            sparse_mode=0,
            atten_mask=attn_metadata.attn_mask,
            pre_tockens=2147483647,
            next_tockens=0,
            actual_seq_qlen=cum_seq_len,
            actual_seq_kvlen=cum_seq_len,
        )
        output[:num_tokens] = attn_out[0]
        return output

    # -----------------------------------------------------------------
    # Chunked prefill — mixed prefill+decode via npu_fusion_attention
    # (npu_fused_infer_attention_score requires 4D on older CANN)
    # -----------------------------------------------------------------
    def _forward_chunked_prefill(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Chunked prefill: decode tokens via npu_incre_flash_attention,
        prefill tokens via npu_fusion_attention per request."""
        import torch_npu  # noqa: F401

        query_start_loc = attn_metadata.query_start_loc
        seq_lens = attn_metadata.seq_lens

        # Per-request query lengths
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        decode_mask = query_lens == 1
        prefill_mask = ~decode_mask
        num_decodes = decode_mask.sum().item()

        # --- Decode tokens ---
        if num_decodes > 0 and self._key_cache is not None:
            decode_indices = torch.where(decode_mask)[0]
            decode_query = query[query_start_loc[decode_indices]]
            decode_block_tables = attn_metadata.block_tables[decode_indices]
            decode_seq_lens = seq_lens[decode_indices].tolist()
            decode_out = torch_npu.npu_incre_flash_attention(
                decode_query.unsqueeze(1),
                self._key_cache,
                self._value_cache,
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                scale_value=self.scale,
                block_table=decode_block_tables,
                actual_seq_lengths=decode_seq_lens,
                block_size=self._key_cache.shape[1],
                input_layout="BNSD",
            )
            for i, idx in enumerate(decode_indices):
                token_pos = query_start_loc[idx].item()
                output[token_pos] = decode_out[i].squeeze(0)
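
        # Each prefill request is handled one at a time below: its KV is
        # gathered from the paged cache when it has prior context
        # (kv_len > q_len), and the causal mask is offset so the q_len new
        # tokens occupy the last q_len positions of the kv_len-long
        # sequence (see the sketch at the end of this module).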

        # --- Prefill tokens (per-request via npu_fusion_attention) ---
        if prefill_mask.any():
            prefill_indices = torch.where(prefill_mask)[0]
            for idx in prefill_indices:
                start = query_start_loc[idx].item()
                end = query_start_loc[idx + 1].item()
                q_len = end - start
                kv_len = seq_lens[idx].item()
                q = query[start:end]

                if self._key_cache is not None and kv_len > q_len:
                    # Gather KV from paged cache
                    block_table = attn_metadata.block_tables[idx]
                    bs = self._key_cache.shape[1]
                    num_blocks_needed = (kv_len + bs - 1) // bs
                    block_ids = block_table[:num_blocks_needed]
                    gathered_k = self._key_cache[block_ids].reshape(
                        -1, self.num_kv_heads, self.head_size
                    )[:kv_len]
                    gathered_v = self._value_cache[block_ids].reshape(
                        -1, self.num_kv_heads, self.head_size
                    )[:kv_len]
                    causal_mask = torch.ones(
                        q_len, kv_len, dtype=torch.bool, device=query.device
                    ).triu_(diagonal=kv_len - q_len + 1)
                    attn_out = torch_npu.npu_fusion_attention(
                        q.unsqueeze(0),
                        gathered_k.unsqueeze(0),
                        gathered_v.unsqueeze(0),
                        head_num=self.num_heads,
                        input_layout="BSND",
                        scale=self.scale,
                        sparse_mode=0,
                        atten_mask=causal_mask.unsqueeze(0),
                        pre_tockens=kv_len,
                        next_tockens=0,
                    )
                    output[start:end] = attn_out[0].squeeze(0)
                else:
                    # Full self-attention (no prior cache)
                    k = key[start:end]
                    v = value[start:end]
                    causal_mask = torch.ones(
                        q_len, q_len, dtype=torch.bool, device=query.device
                    ).triu_(diagonal=1)
                    attn_out = torch_npu.npu_fusion_attention(
                        q.unsqueeze(0),
                        k.unsqueeze(0),
                        v.unsqueeze(0),
                        head_num=self.num_heads,
                        input_layout="BSND",
                        scale=self.scale,
                        sparse_mode=0,
                        atten_mask=causal_mask.unsqueeze(0),
                        pre_tockens=q_len,
                        next_tockens=0,
                    )
                    output[start:end] = attn_out[0].squeeze(0)
        return output
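

# ---------------------------------------------------------------------
# Illustrative sketch only (hypothetical lengths, helper name is ours; not
# called by the backend): the chunked-prefill mask above uses
# ``triu_(diagonal=kv_len - q_len + 1)`` so that the q_len new tokens sit
# at the tail of the kv_len-long context.  For q_len == 2 appended to
# kv_len == 5 (diagonal == 4):
# ---------------------------------------------------------------------
def _example_chunked_prefill_mask() -> torch.Tensor:
    # tensor([[False, False, False, False,  True],
    #         [False, False, False, False, False]])
    # Row 0 (absolute position 3) may attend to positions 0-3; row 1
    # (absolute position 4) may attend to all five positions.
    return torch.ones(2, 5, dtype=torch.bool).triu_(diagonal=4)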