fix: pure pytorch reshape_and_cache + _npu_flash_attention prefill

This commit is contained in:
2026-02-10 20:33:14 +08:00
parent 30cf7ccd1f
commit 5337842e92

View File

@@ -322,24 +322,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
slot_mapping indices.
"""
import torch_npu # noqa: F401
if kv_cache.numel() > 0:
if self._key_cache is None:
self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
# Ensure contiguous tensors for the NPU op
key = key.contiguous()
value = value.contiguous()
slots = attn_metadata.slot_mapping.long() # indices must be long
slots = attn_metadata.slot_mapping
key_to_cache = key[:attn_metadata.num_actual_tokens]
val_to_cache = value[:attn_metadata.num_actual_tokens]
torch_npu._npu_reshape_and_cache(
key,
value,
self._key_cache,
self._value_cache,
slots,
)
# Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
block_size = self._key_cache.shape[1]
block_idx = slots // block_size
block_offset = slots % block_size
self._key_cache[block_idx, block_offset] = key_to_cache
self._value_cache[block_idx, block_offset] = val_to_cache
return key, value