diff --git a/vllm_npu/attention/attention_v1.py b/vllm_npu/attention/attention_v1.py
index 3c0d687..39c2d87 100644
--- a/vllm_npu/attention/attention_v1.py
+++ b/vllm_npu/attention/attention_v1.py
@@ -84,13 +84,13 @@ class AscendAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
         **kwargs,
-    ) -> Tuple[int, int, int, int]:
-        """KV cache shape: (num_blocks, block_size, num_kv_heads, head_size).
+    ) -> Tuple[int, ...]:
+        """KV cache shape: (2, num_blocks, block_size, num_kv_heads, head_size).
 
-        Key and value caches are allocated as two separate tensors with
-        this shape; they are paired in a ``(key_cache, value_cache)`` tuple.
+        The leading ``2`` stores key and value caches in a single tensor.
+        They are split via ``kv_cache.unbind(0)`` at runtime.
         """
-        return (num_blocks, block_size, num_kv_heads, head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
@@ -346,9 +346,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         # ----------------------------------------------------------
         # Step 1: Update KV cache
         # ----------------------------------------------------------
-        if len(kv_cache) > 1:
+        if kv_cache is not None and len(kv_cache.shape) > 1:
             if self._key_cache is None:
-                self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
+                self._key_cache, self._value_cache = kv_cache.unbind(0)
 
             slots = attn_metadata.slot_mapping
             torch_npu._npu_reshape_and_cache(
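
For context, a minimal sketch (outside the patch) of the layout change this diff introduces: the key and value caches now live in one tensor with a leading dimension of 2 and are separated with `unbind(0)`. The shapes below are placeholder values chosen for illustration, not the backend's real allocation.

```python
import torch

# Hypothetical sizes; in the backend these come from the model config.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# Single allocation matching the new get_kv_cache_shape() return value:
# (2, num_blocks, block_size, num_kv_heads, head_size).
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

# unbind(0) returns views, not copies, so writes through key_cache /
# value_cache land in the shared kv_cache storage.
key_cache, value_cache = kv_cache.unbind(0)
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert key_cache.data_ptr() == kv_cache.data_ptr()
```

This also explains the new guard in the second hunk: a combined cache is a single tensor rather than a `(key, value)` pair, so the code checks `kv_cache.shape` instead of `len(kv_cache)` before splitting it once and reusing the cached views.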