From 5337842e92d31744e0b7cd11b6b0fedbe75db43f Mon Sep 17 00:00:00 2001
From: handsomezhuzhu <2658601135@qq.com>
Date: Tue, 10 Feb 2026 20:33:14 +0800
Subject: [PATCH] fix: pure pytorch reshape_and_cache + _npu_flash_attention
 prefill

---
 vllm_npu/attention/attention_v1.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index f3a9894..10029a3 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -322,24 +322,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
         Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
         slot_mapping indices.
         """
-        import torch_npu  # noqa: F401
-
         if kv_cache.numel() > 0:
             if self._key_cache is None:
                 self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
-            # Ensure contiguous tensors for the NPU op
-            key = key.contiguous()
-            value = value.contiguous()
-            slots = attn_metadata.slot_mapping.long()  # indices must be long
+            slots = attn_metadata.slot_mapping
+            key_to_cache = key[:attn_metadata.num_actual_tokens]
+            val_to_cache = value[:attn_metadata.num_actual_tokens]

-            torch_npu._npu_reshape_and_cache(
-                key,
-                value,
-                self._key_cache,
-                self._value_cache,
-                slots,
-            )
+            # Use pure-PyTorch indexing (ATB reshape_and_cache crashes on this env)
+            block_size = self._key_cache.shape[1]
+            block_idx = slots // block_size
+            block_offset = slots % block_size
+            self._key_cache[block_idx, block_offset] = key_to_cache
+            self._value_cache[block_idx, block_offset] = val_to_cache

         return key, value
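
Note (not part of the patch): below is a minimal standalone sketch of the pure-PyTorch scatter the hunk introduces, for readers who want to see the slot -> (block, offset) decomposition in isolation. The cache layout (num_blocks, block_size, num_kv_heads, head_size), the slot values, and all tensor sizes are assumed purely for illustration; the real shapes come from the backend's kv_cache tensors and scheduler-produced slot_mapping.

import torch

# Assumed paged-cache layout: (num_blocks, block_size, num_kv_heads, head_size).
# All sizes here are made up for the example.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Five new tokens and a hand-written slot_mapping (one flat slot id per token).
key = torch.randn(5, num_kv_heads, head_size)
value = torch.randn(5, num_kv_heads, head_size)
slot_mapping = torch.tensor([0, 1, 17, 18, 35])

# Same decomposition as the patch: flat slot -> (block index, offset within block),
# then scatter with advanced indexing instead of torch_npu._npu_reshape_and_cache.
block_idx = slot_mapping // block_size
block_offset = slot_mapping % block_size
key_cache[block_idx, block_offset] = key
value_cache[block_idx, block_offset] = value

# Slot 17 is block 1, offset 1, so the third token's key lands there.
assert torch.equal(key_cache[1, 1], key[2])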