Files
vllm-npu-plugin/vllm_npu/worker/block_table.py
2026-02-10 23:08:39 +08:00

313 lines
13 KiB
Python

from typing import Optional, Union
import numpy as np
import torch
from vllm.distributed import get_dcp_group
from vllm.utils import cdiv
class BlockTable:
def __init__(self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_sizes: Union[list[int], None] = None):
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
self.physical_block_size = block_size
# If kernel_sizes is None or [0], use physical block size (no splitting)
if kernel_sizes is None or kernel_sizes == [0]:
self.block_size = block_size
self.logical_block_size = block_size
self.blocks_per_phys_block = 1
self.use_hybrid_blocks = False
else:
# Find the first kernel size that divides physical_block_size evenly
selected_kernel_size = None
for kernel_size in kernel_sizes:
if kernel_size > 0 \
and self.physical_block_size % kernel_size == 0:
selected_kernel_size = kernel_size
break
if selected_kernel_size is None:
raise ValueError(
f"None of the kernel sizes {kernel_sizes} can divide "
f"physical block size {self.physical_block_size} evenly")
self.block_size = selected_kernel_size
self.logical_block_size = selected_kernel_size
self.blocks_per_phys_block = (self.physical_block_size //
self.logical_block_size)
if self.blocks_per_phys_block > 1:
self.use_hybrid_blocks = True
else:
self.use_hybrid_blocks = False
if self.use_hybrid_blocks:
logical_table_size = (max_num_blocks_per_req *
self.blocks_per_phys_block)
else:
logical_table_size = max_num_blocks_per_req
self.block_table = torch.zeros(
(max_num_reqs, logical_table_size),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, logical_table_size),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.kernel_sizes = kernel_sizes
def append_row(
self,
block_ids,
row_idx: int,
) -> None:
if not block_ids:
return
block_ids = np.array(block_ids)
if self.use_hybrid_blocks:
block_ids = self._convert_physical_to_logical_blocks(block_ids)
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] += num_blocks
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
num_blocks_src = self.num_blocks_per_row[src]
num_blocks_tgt = self.num_blocks_per_row[tgt]
self.num_blocks_per_row[src] = num_blocks_tgt
self.num_blocks_per_row[tgt] = num_blocks_src
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
if self.dcp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
# IMPORTANT: In hybrid mode, positions are in logical block space,
# but we need to map them to the correct logical block table indices
logical_block_idx = positions // virtual_block_size
# Account for the expanded logical table
# (always needed with unified tensor)
# Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this
block_table_indices = (req_indices * self.max_num_blocks_per_req *
self.blocks_per_phys_block +
logical_block_idx)
block_numbers = self.block_table_np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
else:
assert self.kernel_sizes is not None
if self.block_size == self.kernel_sizes[0]:
# IMPORTANT: In hybrid mode, positions are in logical block space,
# but we need to map them to the correct logical block table indices
logical_block_idx = positions // self.block_size
# Account for the expanded logical table
# (always needed with unified tensor)
# Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this
block_table_indices = (
req_indices * self.max_num_blocks_per_req *
self.blocks_per_phys_block + logical_block_idx)
block_numbers = self.block_table_np.ravel(
)[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:req_indices.shape[0]])
def commit_block_table(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping[:num_tokens].copy_(
self.slot_mapping_cpu[:num_tokens], non_blocking=True)
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
def _convert_physical_to_logical_blocks(
self, physical_blocks: np.ndarray) -> np.ndarray:
"""Convert physical block IDs to logical block IDs."""
if not self.use_hybrid_blocks:
return physical_blocks
# Create logical block IDs by splitting each physical block
logical_blocks: list[int] = []
for phys_block in physical_blocks:
# Convert physical block to multiple logical blocks
# Physical block 1 becomes logical blocks
# [1*split_ratio, 1*split_ratio+1, ...]
# But we need to account for the fact that block 0 is special
base_logical = phys_block * self.blocks_per_phys_block
logical_blocks.extend(
range(base_logical, base_logical + self.blocks_per_phys_block))
return np.array(logical_blocks, dtype=np.int32)
def get_device_tensor(self) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table_cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
kernel_sizes: Optional[list[list[int]]] = None) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
# Ensure kernel_sizes matches block_sizes length
elif len(kernel_sizes) == 1 and len(block_sizes) > 1:
kernel_sizes = kernel_sizes * len(block_sizes)
elif len(kernel_sizes) != len(block_sizes):
raise ValueError(
f"kernel_sizes length ({len(kernel_sizes)}) must match "
f"block_sizes length ({len(block_sizes)})")
# Use zip to pair block_sizes with kernel_sizes one-to-one
self.block_tables = [
BlockTable(
block_size, max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device, kernel_size_list)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
]
def append_row(self, block_ids: tuple[list[int], ...],
row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]