fix: replace ATB reshape_and_cache with pure PyTorch indexing

2026-02-10 19:56:47 +08:00
parent 101435817a
commit b8b4516b98


@@ -351,13 +351,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self._key_cache, self._value_cache = kv_cache.unbind(0)
         slots = attn_metadata.slot_mapping
-        torch_npu._npu_reshape_and_cache(
-            key=key[:num_actual_tokens],
-            value=value[:num_actual_tokens],
-            key_cache=self._key_cache,
-            value_cache=self._value_cache,
-            slot_indices=slots,
-        )
+        # Pure PyTorch reshape_and_cache (avoids the ATB dependency).
+        # Slice slots as well, since slot_mapping may be padded beyond
+        # num_actual_tokens.
+        key_to_cache = key[:num_actual_tokens]
+        val_to_cache = value[:num_actual_tokens]
+        slots_to_cache = slots[:num_actual_tokens]
+        block_size = self._key_cache.shape[1]
+        block_idx = slots_to_cache // block_size
+        block_offset = slots_to_cache % block_size
+        self._key_cache[block_idx, block_offset] = key_to_cache
+        self._value_cache[block_idx, block_offset] = val_to_cache
         # ----------------------------------------------------------
         # Step 2: Compute attention
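
The indexing scheme can be sanity-checked off-device. Below is a minimal,
self-contained sketch with plain CPU tensors; num_blocks, block_size,
num_heads, head_dim and the slot values are made up for illustration, not
taken from this file. It shows how flat slot indices decompose into
(block, offset) pairs and scatter token-major keys into a paged cache:

    import torch

    # Hypothetical shapes for illustration only; the real values come from
    # the model config and the KV-cache allocator.
    num_blocks, block_size, num_heads, head_dim = 4, 8, 2, 16
    num_tokens = 5

    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
    key = torch.randn(num_tokens, num_heads, head_dim)

    # slot_mapping gives each token a flat slot in [0, num_blocks * block_size).
    slots = torch.tensor([0, 1, 9, 17, 31])

    # Decompose flat slots into (block, offset) pairs and scatter via
    # advanced indexing -- the same arithmetic as the diff above.
    block_idx = slots // block_size
    block_offset = slots % block_size
    key_cache[block_idx, block_offset] = key

    # Read-back check: every cached row matches its source token.
    assert torch.equal(key_cache[block_idx, block_offset], key)

One caveat: if two tokens ever mapped to the same slot, the write order of an
advanced-indexing assignment is unspecified, so this path relies on
slot_mapping assigning each token a distinct slot.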