fix: revert to _npu_reshape_and_cache (contiguous) and _npu_flash_attention

Date: 2026-02-10 20:29:18 +08:00
parent a58c3fe973
commit 30cf7ccd1f

@@ -322,21 +322,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
         Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
         slot_mapping indices.
         """
+        import torch_npu  # noqa: F401
         if kv_cache.numel() > 0:
             if self._key_cache is None:
                 self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
-            slots = attn_metadata.slot_mapping
-            key_to_cache = key[:attn_metadata.num_actual_tokens]
-            val_to_cache = value[:attn_metadata.num_actual_tokens]
-            # Use pure-PyTorch indexing (ATB reshape_and_cache may fail
-            # depending on environment; this is functionally identical)
-            block_size = self._key_cache.shape[1]
-            block_idx = slots // block_size
-            block_offset = slots % block_size
-            self._key_cache[block_idx, block_offset] = key_to_cache
-            self._value_cache[block_idx, block_offset] = val_to_cache
+            # Ensure contiguous tensors for the NPU op
+            key = key.contiguous()
+            value = value.contiguous()
+            slots = attn_metadata.slot_mapping.long()  # indices must be long
+            torch_npu._npu_reshape_and_cache(
+                key,
+                value,
+                self._key_cache,
+                self._value_cache,
+                slots,
+            )
         return key, value
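Note: the restored _npu_reshape_and_cache op and the removed pure-PyTorch path are meant to write the same data: each entry of slot_mapping is a flat slot id that decomposes into a (block, offset) cell of the paged KV cache. Below is a minimal host-side sketch of that decomposition, assuming a cache shaped [num_blocks, block_size, num_kv_heads, head_size]; only the block_size dimension (shape[1]) is actually implied by the diff.

    import torch

    def write_kv_by_slots(key: torch.Tensor,          # [num_tokens, num_kv_heads, head_size]
                          value: torch.Tensor,
                          key_cache: torch.Tensor,    # [num_blocks, block_size, num_kv_heads, head_size]
                          value_cache: torch.Tensor,
                          slot_mapping: torch.Tensor  # [num_tokens], flat slot ids
                          ) -> None:
        """Pure-PyTorch reference for the paged-cache write (slot -> block/offset)."""
        slots = slot_mapping.long()          # advanced indexing needs integer (long) indices
        block_size = key_cache.shape[1]
        block_idx = slots // block_size      # which block each token's slot falls in
        block_offset = slots % block_size    # position of the token inside that block
        key_cache[block_idx, block_offset] = key
        value_cache[block_idx, block_offset] = value

This mirrors the removed code path and can serve as a CPU-side reference when validating what the NPU op writes.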
@@ -450,28 +453,33 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Prefill attention without KV cache (self-attention)."""
+        """Prefill attention without KV cache (self-attention) using _npu_flash_attention."""
         import torch_npu  # noqa: F401
-        cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
-        attn_out = torch_npu.npu_fusion_attention(
-            query[:num_tokens],
-            key[:num_tokens],
-            value[:num_tokens],
-            head_num=self.num_heads,
-            input_layout="TND",
-            scale=self.scale,
-            sparse_mode=0,
-            atten_mask=attn_metadata.attn_mask,
-            pre_tockens=2147483647,
-            next_tockens=0,
-            actual_seq_qlen=cum_seq_len,
-            actual_seq_kvlen=cum_seq_len,
+        # Huawei uses _npu_flash_attention for prefill
+        # Ensure contiguous inputs
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+        # mask needs to be contiguous and cast to expected format if needed
+        # but _npu_flash_attention handles generic mask?
+        # Huawei code: mask = attn_metadata.attn_mask...
+        # We'll pass it as is, assuming AscendMetadataBuilder created it correctly.
+        torch_npu._npu_flash_attention(
+            query=query,
+            key=key,
+            value=value,
+            mask=attn_metadata.attn_mask,
+            seq_len=attn_metadata.seq_lens,
+            scale_value=self.scale,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            out=output
         )
-        output[:num_tokens] = attn_out[0]
-        return output
+        return output[:num_tokens]
 
     # -----------------------------------------------------------------
     # Chunked prefill — mixed prefill+decode via npu_fusion_attention
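Note: the two prefill paths describe batch lengths differently. The removed npu_fusion_attention call takes cumulative end offsets (query_start_loc[1:]) via actual_seq_qlen/actual_seq_kvlen, while the restored _npu_flash_attention call takes per-sequence lengths (attn_metadata.seq_lens). A minimal sketch of the conversion, assuming query_start_loc holds the usual prefix sums [0, l0, l0+l1, ...] (that layout is an assumption, not stated in the diff):

    import torch

    def seq_lens_from_query_start_loc(query_start_loc: torch.Tensor) -> torch.Tensor:
        """[0, l0, l0+l1, ...] -> per-sequence lengths [l0, l1, ...]."""
        return query_start_loc[1:] - query_start_loc[:-1]

    def cumulative_end_offsets(seq_lens: torch.Tensor) -> torch.Tensor:
        """[l0, l1, ...] -> cumulative end offsets [l0, l0+l1, ...]."""
        return torch.cumsum(seq_lens, dim=0)

    # Example: three prefill sequences of lengths 4, 2, 7
    start_loc = torch.tensor([0, 4, 6, 13])
    lens = seq_lens_from_query_start_loc(start_loc)   # tensor([4, 2, 7])
    ends = cumulative_end_offsets(lens)               # tensor([ 4,  6, 13])
    assert torch.equal(ends, start_loc[1:])

For prefill without a KV cache the query and KV lengths coincide, which is why the old call passed the same cum_seq_len for both actual_seq_qlen and actual_seq_kvlen.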