From b8b4516b988847968bfbacd8fcc87a5427773e37 Mon Sep 17 00:00:00 2001
From: handsomezhuzhu <2658601135@qq.com>
Date: Tue, 10 Feb 2026 19:56:47 +0800
Subject: [PATCH] fix: replace ATB reshape_and_cache with pure PyTorch indexing

---
 vllm_npu/attention/attention_v1.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index 39c2d87..da9e7b3 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -351,13 +351,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self._key_cache, self._value_cache = kv_cache.unbind(0)
         slots = attn_metadata.slot_mapping
 
-        torch_npu._npu_reshape_and_cache(
-            key=key[:num_actual_tokens],
-            value=value[:num_actual_tokens],
-            key_cache=self._key_cache,
-            value_cache=self._value_cache,
-            slot_indices=slots,
-        )
+        # Pure PyTorch reshape_and_cache (avoids ATB dependency)
+        key_to_cache = key[:num_actual_tokens]
+        val_to_cache = value[:num_actual_tokens]
+        block_size = self._key_cache.shape[1]
+        block_idx = slots // block_size
+        block_offset = slots % block_size
+        self._key_cache[block_idx, block_offset] = key_to_cache
+        self._value_cache[block_idx, block_offset] = val_to_cache
+
         # ----------------------------------------------------------
         # Step 2: Compute attention
 
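For reference, below is a minimal standalone sketch of the slot-to-block indexing scheme the patch relies on. All tensor shapes and slot values are illustrative assumptions chosen for the example and are not taken from vllm_npu; only the two-step decomposition (slots // block_size, slots % block_size) mirrors the patched code.

import torch

# Assumed toy dimensions: 4 blocks of 8 slots, 2 KV heads, head size 16.
num_blocks, block_size, num_heads, head_dim = 4, 8, 2, 16
num_tokens = 5

key_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
value_cache = torch.zeros_like(key_cache)

key = torch.randn(num_tokens, num_heads, head_dim)
value = torch.randn(num_tokens, num_heads, head_dim)

# Flat slot ids assigned by the scheduler; each new token owns one slot.
slots = torch.tensor([3, 9, 10, 17, 25])

block_idx = slots // block_size    # physical block holding each token
block_offset = slots % block_size  # position of the token inside its block

# Advanced indexing scatters all tokens into the paged cache in one call.
key_cache[block_idx, block_offset] = key
value_cache[block_idx, block_offset] = value

# Sanity check: the token written to slot 9 is readable at block 1, offset 1.
assert torch.equal(key_cache[1, 1], key[1])

One caveat worth noting: plain indexing writes every slot it is given. If slot_mapping can contain padding entries such as -1 on some code paths, those would need to be masked out first, because a negative index wraps around to the end of the cache tensor.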