fix: KV cache shape needs leading 2 dim for key+value pair
@@ -84,13 +84,13 @@ class AscendAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
         **kwargs,
-    ) -> Tuple[int, int, int, int]:
-        """KV cache shape: (num_blocks, block_size, num_kv_heads, head_size).
+    ) -> Tuple[int, ...]:
+        """KV cache shape: (2, num_blocks, block_size, num_kv_heads, head_size).
 
-        Key and value caches are allocated as two separate tensors with
-        this shape; they are paired in a ``(key_cache, value_cache)`` tuple.
+        The leading ``2`` stores key and value caches in a single tensor.
+        They are split via ``kv_cache.unbind(0)`` at runtime.
         """
-        return (num_blocks, block_size, num_kv_heads, head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
 
     @staticmethod
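
As a rough illustration of the new layout (not part of the commit; the tensor sizes below are made up), key and value caches now live in one tensor whose leading dimension of 2 selects key vs. value, and ``unbind(0)`` splits it back into per-role views without copying:

    import torch

    # Illustrative sizes only; the real values come from the model/cache config.
    num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 64

    # New layout: a single tensor whose leading dim of 2 holds (key, value).
    kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

    # unbind(0) returns two views, one per role, sharing the same storage.
    key_cache, value_cache = kv_cache.unbind(0)
    assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
    assert value_cache.shape == key_cache.shape
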
@@ -346,9 +346,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         # ----------------------------------------------------------
         # Step 1: Update KV cache
         # ----------------------------------------------------------
-        if len(kv_cache) > 1:
+        if kv_cache is not None and len(kv_cache.shape) > 1:
             if self._key_cache is None:
-                self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
+                self._key_cache, self._value_cache = kv_cache.unbind(0)
 
             slots = attn_metadata.slot_mapping
             torch_npu._npu_reshape_and_cache(
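
A minimal sketch of why the guard changes, assuming (this is not taken from the plugin) that an empty placeholder tensor can be passed in place of a real cache: on the combined tensor, ``len(kv_cache)`` is just the size of the leading dim, which is always 2, so only the number of dimensions distinguishes a real cache from a placeholder:

    import torch

    real_cache = torch.zeros(2, 128, 16, 8, 64)  # populated combined KV cache (sizes illustrative)
    dummy_cache = torch.tensor([])               # assumed empty placeholder

    for kv_cache in (real_cache, dummy_cache):
        if kv_cache is not None and len(kv_cache.shape) > 1:
            # unbind(0) yields views, so later writes through key_cache /
            # value_cache update the shared kv_cache storage in place.
            key_cache, value_cache = kv_cache.unbind(0)
            print("split into", key_cache.shape, "and", value_cache.shape)
        else:
            print("skipping placeholder cache")
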