fix: KV cache shape needs leading 2 dim for key+value pair

2026-02-10 19:27:10 +08:00
parent a274fd82ad
commit 7120cd803b


@@ -84,13 +84,13 @@ class AscendAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
         **kwargs,
-    ) -> Tuple[int, int, int, int]:
-        """KV cache shape: (num_blocks, block_size, num_kv_heads, head_size).
+    ) -> Tuple[int, ...]:
+        """KV cache shape: (2, num_blocks, block_size, num_kv_heads, head_size).
 
-        Key and value caches are allocated as two separate tensors with
-        this shape; they are paired in a ``(key_cache, value_cache)`` tuple.
+        The leading ``2`` stores key and value caches in a single tensor.
+        They are split via ``kv_cache.unbind(0)`` at runtime.
         """
-        return (num_blocks, block_size, num_kv_heads, head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
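
A minimal sketch (plain torch, with illustrative sizes) of what the new shape buys: one allocation holds both caches, and ``unbind(0)`` splits it into key and value views without copying.

import torch

# Illustrative sizes; the real values come from the model and cache config.
num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 64

# One tensor for both caches, matching the new shape contract above.
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

# unbind(0) returns views into the same storage, not copies.
key_cache, value_cache = kv_cache.unbind(0)
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert key_cache.untyped_storage().data_ptr() == value_cache.untyped_storage().data_ptr()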
@@ -346,9 +346,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         # ----------------------------------------------------------
         # Step 1: Update KV cache
         # ----------------------------------------------------------
-        if len(kv_cache) > 1:
+        if kv_cache is not None and len(kv_cache.shape) > 1:
             if self._key_cache is None:
-                self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
+                self._key_cache, self._value_cache = kv_cache.unbind(0)
             slots = attn_metadata.slot_mapping
             torch_npu._npu_reshape_and_cache(
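
A rough torch-only sketch of the updated runtime path. ``KVCacheUpdater`` and its flattened indexed write are hypothetical stand-ins for the backend impl and the ``torch_npu._npu_reshape_and_cache`` kernel (whose real signature is not shown in this diff); the guard and the one-time ``unbind(0)`` split mirror the change above.

import torch

class KVCacheUpdater:
    """Toy stand-in for the backend impl's Step 1 (update KV cache)."""

    def __init__(self):
        self._key_cache = None
        self._value_cache = None

    def update(self, kv_cache, key, value, slot_mapping):
        # A missing or rank-1 kv_cache (e.g. a profiling placeholder) is
        # skipped, same as the guard the diff introduces.
        if kv_cache is not None and len(kv_cache.shape) > 1:
            if self._key_cache is None:
                # Split once into key/value views; later calls reuse them.
                self._key_cache, self._value_cache = kv_cache.unbind(0)
            # Stand-in for the NPU reshape-and-cache kernel: flatten the
            # block dims so each slot index addresses one token's cache line.
            k_flat = self._key_cache.view(-1, *self._key_cache.shape[2:])
            v_flat = self._value_cache.view(-1, *self._value_cache.shape[2:])
            k_flat[slot_mapping] = key
            v_flat[slot_mapping] = value

updater = KVCacheUpdater()
kv_cache = torch.zeros(2, 4, 16, 8, 64)
key, value = torch.randn(3, 8, 64), torch.randn(3, 8, 64)
updater.update(kv_cache, key, value, torch.tensor([5, 17, 40]))
# Writes land in the shared tensor: kv_cache[0] is the key cache.
assert torch.equal(kv_cache[0].reshape(-1, 8, 64)[[5, 17, 40]], key)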