fix: accept cache_dtype_str in get_kv_cache_shape

This commit is contained in:
2026-02-10 19:23:20 +08:00
parent c3631d65c2
commit a274fd82ad

View File

@@ -83,6 +83,7 @@ class AscendAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
**kwargs,
) -> Tuple[int, int, int, int]:
"""KV cache shape: (num_blocks, block_size, num_kv_heads, head_size).
@@ -91,6 +92,7 @@ class AscendAttentionBackend(AttentionBackend):
""" """
return (num_blocks, block_size, num_kv_heads, head_size) return (num_blocks, block_size, num_kv_heads, head_size)
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
src_kv_cache: List[torch.Tensor], src_kv_cache: List[torch.Tensor],