Files
vllm-npu-plugin/vllm_npu/attention/attention_v1.py

515 lines
17 KiB
Python

"""
Ascend NPU attention backend for vLLM v1.
Implements the ``AttentionBackend``, ``AttentionMetadata``,
``AttentionMetadataBuilder``, and ``AttentionImpl`` interfaces using
Huawei Ascend NPU FlashAttention operators:
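- ``torch_npu._npu_flash_attention`` — prefill attention (TND layout)
- ``torch_npu._npu_reshape_and_cache`` — KV cache update
- ``torch_npu._npu_paged_attention`` — paged-attention decode
- ``torch_npu.npu_fused_infer_attention_score`` — chunked prefill and
  mixed prefill/decode batches (TND layout)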
"""
from dataclasses import dataclass
from enum import IntEnum
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
)
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
# Module-level logger created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
# =====================================================================
# Attention state enum
# =====================================================================
class AscendAttentionState(IntEnum):
    """Kind of batch being processed; selects the attention kernel path."""

    # Every request's query length equals its sequence length.
    PrefillNoCache = 0
    # Declared for completeness; the metadata builder in this module never
    # produces this state.
    PrefillCacheHit = 1
    # Every request contributes exactly one query token.
    DecodeOnly = 2
    # Any other batch (prefill with cached context, or mixed batches).
    ChunkedPrefill = 3
# =====================================================================
# Backend class
# =====================================================================
class AscendAttentionBackend(AttentionBackend):
    """Backend descriptor for Ascend NPU attention.

    Exposes the implementation, metadata and metadata-builder classes
    defined in this module, the KV cache layout, and block-level helpers
    for copying cache contents.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ASCEND_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the attention implementation class."""
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AscendMetadata"]:
        """Return the attention metadata class."""
        return AscendMetadata

    @staticmethod
    def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return AscendAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        **kwargs,
    ) -> Tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Index 0 of the leading dimension is the key cache and index 1 the
        value cache; both live in one tensor. Extra keyword arguments are
        accepted and ignored.
        """
        key_and_value = 2
        return (key_and_value, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        """Copy blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        Args:
            src_kv_cache: (key_cache, value_cache) pair to read from.
            dst_kv_cache: (key_cache, value_cache) pair to write to.
            src_to_dst: pairs of ``(src_block, dst_block)`` indices.
        """
        src_keys, src_values = src_kv_cache
        dst_keys, dst_values = dst_kv_cache
        transfers = ((src_keys, dst_keys), (src_values, dst_values))
        for src_block, dst_block in src_to_dst.tolist():
            # Key block first, then value block, for each index pair.
            for source, target in transfers:
                target[dst_block].copy_(source[src_block])

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        """Copy blocks within each KV cache in ``kv_caches``, in place.

        Args:
            kv_caches: cache tensors whose index 0 is the key cache and
                index 1 the value cache.
            src_to_dsts: pairs of ``(src_block, dst_block)`` indices,
                applied in order.
        """
        # All key caches first, then all value caches.
        caches = [kv[0] for kv in kv_caches] + [kv[1] for kv in kv_caches]
        for src_block, dst_block in src_to_dsts.tolist():
            for cache in caches:
                cache[dst_block].copy_(cache[src_block])
# =====================================================================
# Metadata dataclass
# =====================================================================
@dataclass
class AscendMetadata:
    """Attention metadata consumed by the Ascend attention implementation."""

    # Which kernel path the implementation should take.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    # Number of query tokens in the batch that carry real data.
    num_actual_tokens: int = 0

    # --- Sequence lengths and query positions ---------------------------
    # Per-request sequence lengths, shape (batch,).
    seq_lens: Optional[torch.Tensor] = None
    # Same lengths as a plain Python list.
    seq_lens_list: Optional[List[int]] = None
    # Cumulative query offsets, shape (batch + 1,).
    query_start_loc: Optional[torch.Tensor] = None
    # Per-request query lengths; the builder in this module leaves it unset.
    query_lens: Optional[torch.Tensor] = None
    # Longest query in the batch.
    max_query_len: Optional[int] = None
    # Cumulative query end positions as a Python list.
    actual_seq_lengths_q: Optional[List[int]] = None

    # --- KV cache mapping ------------------------------------------------
    # Block table, shape (batch, max_blocks).
    block_tables: Optional[torch.Tensor] = None
    # Cache slot of every token, shape (num_tokens,).
    slot_mapping: Optional[torch.Tensor] = None

    # --- Masking ---------------------------------------------------------
    # Causal mask used by the prefill paths; None for decode-only batches.
    attn_mask: Optional[torch.Tensor] = None
# =====================================================================
# Metadata builder
# =====================================================================
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
    """Builds ``AscendMetadata`` from ``CommonAttentionMetadata``."""

    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )
    reorder_batch_threshold: ClassVar[int] = 1
    # Side length of the compressed causal mask. The chunked-prefill path in
    # this module calls ``npu_fused_infer_attention_score`` with
    # ``sparse_mode=3``, which is documented to take a fixed 2048x2048
    # upper-triangular mask rather than a (max_seq, max_seq) one.
    compressed_mask_size: ClassVar[int] = 2048

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: "VllmConfig",
        device: torch.device,
    ):
        """Store the KV cache geometry taken from ``kv_cache_spec``."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.block_size = kv_cache_spec.block_size
        self.num_kv_heads = kv_cache_spec.num_kv_heads
        self.head_size = kv_cache_spec.head_size
        # Built lazily on the first chunked-prefill batch and then reused.
        self._compressed_causal_mask: Optional[torch.Tensor] = None

    def reorder_batch(
        self,
        input_batch: "InputBatch",
        scheduler_output: "SchedulerOutput",
    ) -> bool:
        """Reorder so decodes (query_len == 1) come first, prefills after.

        Returns:
            Whatever ``reorder_batch_to_split_decodes_and_prefills`` returns.
        """
        from vllm.v1.attention.backends.utils import (
            reorder_batch_to_split_decodes_and_prefills,
        )

        return reorder_batch_to_split_decodes_and_prefills(
            input_batch, scheduler_output, decode_threshold=1
        )

    def _causal_mask(self, size: int) -> torch.Tensor:
        """Return a ``(size, size)`` bool mask, True strictly above the diagonal."""
        return torch.ones(
            size,
            size,
            dtype=torch.bool,
            device=self.device,
        ).triu_(diagonal=1)

    def _get_compressed_causal_mask(self) -> torch.Tensor:
        """Return the cached fixed-size causal mask, creating it on first use."""
        if self._compressed_causal_mask is None:
            self._compressed_causal_mask = self._causal_mask(
                self.compressed_mask_size
            )
        return self._compressed_causal_mask

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMetadata:
        """Build ``AscendMetadata`` from the common attention metadata.

        Args:
            common_prefix_len: Unused by this builder.
            common_attn_metadata: Batch description supplied by the runner.
            fast_build: Unused by this builder.

        Returns:
            Metadata whose ``attn_state`` is ``DecodeOnly`` when the longest
            query is one token, ``PrefillNoCache`` when every request's
            query length equals its sequence length, and ``ChunkedPrefill``
            otherwise.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        # Host-side copies, sliced once to the real number of requests.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[
            : num_reqs + 1
        ]
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs]

        # Determine attention state.
        if max_query_len == 1:
            attn_state = AscendAttentionState.DecodeOnly
        else:
            query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
            if (query_lens_cpu == seq_lens_cpu).all():
                attn_state = AscendAttentionState.PrefillNoCache
            else:
                attn_state = AscendAttentionState.ChunkedPrefill

        # Select the causal mask for the chosen kernel path.
        if attn_state == AscendAttentionState.DecodeOnly:
            attn_mask = None
        elif attn_state == AscendAttentionState.PrefillNoCache:
            # Full mask sized to the longest sequence in the batch.
            attn_mask = self._causal_mask(common_attn_metadata.max_seq_len)
        else:
            # Fixed-size compressed mask, shared across steps.
            attn_mask = self._get_compressed_causal_mask()

        # NOTE(review): ``seq_lens`` is forwarded exactly as provided (not
        # sliced to ``num_reqs``, device unchanged) — confirm the kernels
        # accept it in that form.
        return AscendMetadata(
            attn_state=attn_state,
            num_actual_tokens=num_actual_tokens,
            seq_lens=common_attn_metadata.seq_lens,
            seq_lens_list=seq_lens_cpu.tolist(),
            query_start_loc=common_attn_metadata.query_start_loc.to(
                dtype=torch.int64
            ),
            max_query_len=max_query_len,
            # Cumulative query end positions (the leading 0 is dropped).
            actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
            block_tables=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            attn_mask=attn_mask,
        )

    def build_for_cudagraph_capture(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> AscendMetadata:
        """Build metadata for graph capture by delegating to ``build``."""
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
# =====================================================================
# Attention implementation
# =====================================================================
class AscendAttentionBackendImpl(AttentionImpl):
    """Ascend NPU attention kernel implementation.

    New keys/values are first written to the paged KV cache with
    ``torch_npu._npu_reshape_and_cache``. Attention is then dispatched on
    ``AscendMetadata.attn_state``:

    - ``DecodeOnly``: ``torch_npu._npu_paged_attention``
    - ``PrefillNoCache``: ``torch_npu._npu_flash_attention``
    - otherwise: ``torch_npu.npu_fused_infer_attention_score``
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Record the attention geometry and scaling factor.

        NOTE(review): ``alibi_slopes`` and ``sliding_window`` are stored but
        never passed to a kernel in this class, and ``logits_soft_cap`` and
        ``kv_sharing_target_layer_name`` are accepted and ignored. Confirm
        that models relying on them are rejected elsewhere.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        self.attn_type = attn_type
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(
                alibi_slopes, dtype=torch.float32, device="npu"
            )
        self.alibi_slopes = alibi_slopes
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Views onto the key/value halves of the most recent KV cache tensor.
        self._key_cache: Optional[torch.Tensor] = None
        self._value_cache: Optional[torch.Tensor] = None

    def _bind_kv_cache(self, kv_cache: Optional[torch.Tensor]) -> bool:
        """Point the cached key/value views at ``kv_cache``.

        The views are refreshed on every call so that a replaced KV cache
        tensor is never read through stale references.

        Returns:
            True when ``kv_cache`` is non-empty and the views were bound.
        """
        if kv_cache is None or kv_cache.numel() == 0:
            return False
        self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
        return True

    def reshape_and_cache(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "AscendMetadata",
    ):
        """Write new key/value tensors into the paged KV cache.

        Uses ``torch_npu._npu_reshape_and_cache``. Nothing is written when
        ``kv_cache`` is empty.

        Returns:
            The ``(key, value)`` pair that was passed in, unchanged.
        """
        import torch_npu  # noqa: F401

        if self._bind_kv_cache(kv_cache):
            num_actual = attn_metadata.num_actual_tokens
            torch_npu._npu_reshape_and_cache(
                key=key[:num_actual],
                value=value[:num_actual],
                key_cache=self._key_cache,
                value_cache=self._value_cache,
                # Trim like key/value so padded slot entries are never used.
                slot_indices=attn_metadata.slot_mapping[:num_actual],
            )
        return key, value

    def forward(
        self,
        layer: nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.

        Args:
            layer: Unused.
            query: reshaped here to (num_tokens, num_heads, head_size).
            key: reshaped here to (num_tokens, num_kv_heads, head_size);
                may be None, in which case the cache update is skipped.
            value: same layout as ``key``; may be None.
            kv_cache: tensor of shape
                (2, num_blocks, block_size, num_kv_heads, head_size)
            attn_metadata: AscendMetadata for this forward call. When None
                the output buffer is zero-filled and returned.
            output: Pre-allocated output buffer (required). The prefill
                paths index it with three dimensions — presumably the
                caller passes (num_tokens, num_heads, head_size); confirm.
            output_scale: Unused.
            output_block_scale: Unused.

        Returns:
            The output buffer (sliced to ``num_tokens`` on the
            prefill-without-cache path).
        """
        import torch_npu  # noqa: F401

        assert output is not None, "Output tensor must be provided."
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)

        # Reshape to TND (tokens, heads, head_dim).
        query = query.view(-1, self.num_heads, self.head_size)

        # Step 1: Update KV cache. The None check must precede the reshape:
        # calling ``.view`` on a missing key/value would raise.
        has_new_kv = key is not None and value is not None
        if has_new_kv:
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
            # TODO: Remove this contiguous in the future.
            value = value.contiguous()
            key, value = self.reshape_and_cache(
                key, value, kv_cache, attn_metadata
            )
        else:
            # No write, but the cache-reading paths still need the views.
            self._bind_kv_cache(kv_cache)

        # Step 2: Compute attention.
        state = attn_metadata.attn_state
        if state == AscendAttentionState.DecodeOnly:
            output = self._forward_decode(
                query, attn_metadata, output, num_tokens
            )
        elif state == AscendAttentionState.PrefillNoCache and has_new_kv:
            output = self._forward_prefill_no_cache(
                query, key, value, attn_metadata, output, num_tokens
            )
        else:
            # ChunkedPrefill, or a prefill without fresh key/value tensors:
            # read keys/values from the paged cache.
            output = self._forward_chunked_prefill(
                query, attn_metadata, output, num_tokens
            )
        return output

    # -----------------------------------------------------------------
    # Decode path — paged attention via _npu_paged_attention
    # -----------------------------------------------------------------
    def _forward_decode(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Decode-only attention via ``_npu_paged_attention``.

        The kernel writes into ``output``, which is returned.
        """
        import torch_npu  # noqa: F401

        assert self._key_cache is not None
        assert self._value_cache is not None
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self._key_cache,
            value_cache=self._value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output,
        )
        return output

    # -----------------------------------------------------------------
    # Prefill without KV cache — _npu_flash_attention (TND layout)
    # -----------------------------------------------------------------
    def _forward_prefill_no_cache(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Prefill attention on the fresh key/value tensors.

        Calls ``_npu_flash_attention`` with the mask from the metadata and
        returns ``output`` sliced to ``num_tokens``.
        """
        import torch_npu  # noqa: F401

        torch_npu._npu_flash_attention(
            query=query,
            key=key,
            value=value,
            mask=attn_metadata.attn_mask,
            seq_len=attn_metadata.seq_lens,
            scale_value=self.scale,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            out=output,
        )
        return output[:num_tokens, :, :]

    # -----------------------------------------------------------------
    # Chunked prefill — npu_fused_infer_attention_score (TND layout)
    # -----------------------------------------------------------------
    def _forward_chunked_prefill(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Chunked prefill / mixed prefill+decode attention.

        Reads keys and values from the paged cache through the block table
        using ``npu_fused_infer_attention_score`` and copies the result
        into the leading rows of ``output``.
        """
        import torch_npu  # noqa: F401

        assert self._key_cache is not None
        assert self._value_cache is not None
        assert attn_metadata.attn_mask is not None
        num_block, block_size, _, _ = self._key_cache.shape
        # Flatten (num_kv_heads, head_size) into one trailing dimension.
        key = self._key_cache.view(num_block, block_size, -1)
        value = self._value_cache.view(num_block, block_size, -1)

        # The kernel requires query.shape[0] to equal the total query
        # length. Prefer the host-side cumulative lengths: slicing with a
        # device tensor scalar would force a device-to-host sync.
        cumulative_q = attn_metadata.actual_seq_lengths_q
        if cumulative_q:
            actual_num_tokens = cumulative_q[-1]
        else:
            actual_num_tokens = int(attn_metadata.query_start_loc[-1])

        out, _ = torch_npu.npu_fused_infer_attention_score(
            query=query[:actual_num_tokens],
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=attn_metadata.block_tables,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )
        output[:actual_num_tokens, :, :] = out[:actual_num_tokens, :, :]
        return output