fix: pure pytorch reshape_and_cache + _npu_flash_attention prefill

2026-02-10 20:33:14 +08:00
parent 30cf7ccd1f
commit 5337842e92


@@ -322,24 +322,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
         Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
         slot_mapping indices.
         """
-        import torch_npu  # noqa: F401
         if kv_cache.numel() > 0:
             if self._key_cache is None:
                 self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
-            # Ensure contiguous tensors for the NPU op
-            key = key.contiguous()
-            value = value.contiguous()
-            slots = attn_metadata.slot_mapping.long()  # indices must be long
-            torch_npu._npu_reshape_and_cache(
-                key,
-                value,
-                self._key_cache,
-                self._value_cache,
-                slots,
-            )
+            slots = attn_metadata.slot_mapping
+            key_to_cache = key[:attn_metadata.num_actual_tokens]
+            val_to_cache = value[:attn_metadata.num_actual_tokens]
+            # Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
+            block_size = self._key_cache.shape[1]
+            block_idx = slots // block_size
+            block_offset = slots % block_size
+            self._key_cache[block_idx, block_offset] = key_to_cache
+            self._value_cache[block_idx, block_offset] = val_to_cache
         return key, value
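
For reference, a standalone sketch of the pure-PyTorch caching path the diff switches to. A flat slot index s lives in block s // block_size at offset s % block_size, so advanced indexing can scatter each token's K/V into its cache cell without the ATB kernel. The cache layout [num_blocks, block_size, num_kv_heads, head_size], the helper name reshape_and_cache_torch, and the self-check values are assumptions for illustration; the patch itself writes into self._key_cache/self._value_cache in place.

import torch

def reshape_and_cache_torch(
    key: torch.Tensor,          # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_kv_heads, head_size]
    key_cache: torch.Tensor,    # [num_blocks, block_size, num_kv_heads, head_size] (assumed layout)
    value_cache: torch.Tensor,  # same shape as key_cache
    slot_mapping: torch.Tensor, # [num_tokens], flat integer slot index per token
) -> None:
    block_size = key_cache.shape[1]
    # Decompose each flat slot into (block index, offset within block).
    block_idx = slot_mapping // block_size
    block_offset = slot_mapping % block_size
    # Advanced indexing scatters every token's K/V into its (block, offset) cell.
    key_cache[block_idx, block_offset] = key
    value_cache[block_idx, block_offset] = value

# Tiny CPU self-check: 2 blocks of 4 slots, 1 KV head, head_size 8.
k_cache = torch.zeros(2, 4, 1, 8)
v_cache = torch.zeros(2, 4, 1, 8)
k = torch.randn(3, 1, 8)
v = torch.randn(3, 1, 8)
slots = torch.tensor([0, 5, 7])  # slot 5 -> block 1, offset 1
reshape_and_cache_torch(k, v, k_cache, v_cache, slots)
assert torch.equal(k_cache[1, 1], k[1])

Because both the NPU kernel and this indexing write to the same paged layout, the change only swaps the write mechanism; note the new code also drops the explicit .long() cast, which assumes slot_mapping already arrives as an integer tensor.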