This commit is contained in:
2026-02-10 23:08:39 +08:00
parent 1baa36026c
commit 6680585975
172 changed files with 52867 additions and 892 deletions

View File

@@ -0,0 +1,202 @@
import time
from collections import defaultdict
from typing import Optional
from vllm.utils import logger, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
PrefixCachingMetrics)
from vllm.v1.core.single_type_kv_cache_manager import \
get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
class CPUCacheStats:
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
self.enable_prefix_caching = enable_prefix_caching
self.log_stats = log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
self.time_sec = int(time.time())
def log(self):
current_time_sec = int(time.time())
# Log the prefix cache hit rate every 10 seconds.
if current_time_sec - self.time_sec >= 10:
self.time_sec = current_time_sec
logger.info("CPU Prefix cache hit rate: %.1f%%",
self.cpu_prefix_cache_metrics.hit_rate * 100)
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
"""
if not self.log_stats:
return None
stats = self.prefix_cache_stats
self.prefix_cache_stats = PrefixCacheStats()
return stats
def update(self, num_tokens, num_computed_tokens):
# Note: this function is called by the scheduler.
if self.log_stats and self.enable_prefix_caching:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += num_tokens
self.prefix_cache_stats.hits += num_computed_tokens
def set_cache_stats(self, num_tokens, num_computed_tokens):
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.hits = num_computed_tokens
self.prefix_cache_stats.queries = num_tokens
self.prefix_cache_stats.requests = 1
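# Illustrative sketch (hypothetical helper, not called anywhere in this file):
# how a scheduler is expected to drive CPUCacheStats for a single request. The
# token counts are made-up example values.
def _example_cpu_cache_stats_flow() -> None:
    stats = CPUCacheStats(enable_prefix_caching=True, log_stats=True)
    # A request with 1024 prompt tokens, 512 of which were already cached.
    stats.update(num_tokens=1024, num_computed_tokens=512)
    snapshot = stats.make_prefix_cache_stats()  # returns and resets the stats
    if snapshot is not None:
        # Aggregated hit rate over recent requests: hits / queries = 50%.
        stats.cpu_prefix_cache_metrics.observe(snapshot)
    stats.log()  # emits the hit-rate line at most once every 10 seconds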
class CPUKVCacheManager:
def __init__(
self,
kv_cache_spec: KVCacheSpec,
num_cpu_blocks: int,
caching_hash_algo: str = "builtin",
use_eagle: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
self.block_size = kv_cache_spec.block_size
self.num_cpu_blocks = num_cpu_blocks
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle
self.block_pool = BlockPool(self.num_cpu_blocks, True,
enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=0,
)
# Record KV block hashes to avoid redundant computation.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
# Record blocks touched in get_matched_num_and_touch().
self.req_to_computed_blocks: defaultdict[
str, list[KVCacheBlock]] = defaultdict(list)
# Record the request that failed to allocate.
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
log_stats=True)
# Record requests to be freed after their KV transfer finishes sending.
self.req_to_free: dict[str, Request] = {}
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
# When the request requires prompt logprobs, we skip prefix caching.
if (request.sampling_params.prompt_logprobs is not None):
return 0, False
request_id = request.request_id
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request_id]
if not block_hashes:
block_hashes = request.block_hashes
self.req_to_block_hashes[request_id] = block_hashes
max_cache_hit_length = request.num_tokens - 1
computed_blocks = self.single_type_manager.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.single_type_manager.kv_cache_spec,
use_eagle=self.use_eagle,
)
num_computed_tokens = len(computed_blocks[0]) * self.block_size
self.req_to_computed_blocks[request_id] = computed_blocks[0]
# Touch these blocks so that concurrent requests cannot evict them before allocation.
self.block_pool.touch(computed_blocks)
# Set and log the CPU prefix cache stats.
assert (self.cpu_cache_stats is not None
        and self.cpu_cache_stats.prefix_cache_stats is not None)
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.log()
return num_computed_tokens, False
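# Worked example (hypothetical numbers): with block_size=128 and a request of
# 1000 prompt tokens, max_cache_hit_length is 999, so at most 999 // 128 = 7
# full blocks can match and num_computed_tokens is at most 896. Capping at
# num_tokens - 1 guarantees that at least one token is always recomputed.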
def _release_ahead_touch(self, request_id: str):
computed_blocks = self.req_to_computed_blocks[request_id]
if computed_blocks:
self.single_type_manager.block_pool.free_blocks(
reversed(computed_blocks))
self.req_to_computed_blocks.pop(request_id, None)
def allocate_slots(self, req_to_num_tokens: dict[str, int],
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
for request_id in unallocated_req_ids:
self._free_slots(request_id)
req_to_new_blocks = {}
for request_id, num_tokens in req_to_num_tokens.items():
if self.req_failed_to_allocate[request_id]:
continue
new_computed_blocks = self.req_to_computed_blocks[request_id]
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
))
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
self._release_ahead_touch(request_id)
self.req_failed_to_allocate[request_id] = True
continue
# Save the matched computed blocks for this request first so that its
# bookkeeping stays consistent even if the new blocks below cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request_id, new_computed_blocks)
# Allocate new blocks but do not cache now.
new_blocks = self.single_type_manager.allocate_new_blocks(
request_id, num_tokens)
self.req_to_num_tokens[request_id] = num_tokens
# The ref_cnt taken by touch() in get_matched_num_and_touch() is now owned by
# the request's block list, so it does not need to be released here.
self.req_to_computed_blocks.pop(request_id, None)
req_to_new_blocks[request_id] = [
block.block_id for block in new_computed_blocks + new_blocks
]
return req_to_new_blocks
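# Worked example (hypothetical numbers): a request scheduled for 1000 tokens
# with 7 matched computed blocks (block_size=128) still needs roughly
# ceil(1000 / 128) - 7 = 1 new block. If the pool has fewer free blocks than
# that, the touched blocks are released and the request is marked as failed
# until the scheduler retries it.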
def record_request_cache_and_free_slots(self, request: Request):
logger.debug(
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
)
self.req_to_free[request.request_id] = request
def cache_and_free_slots(self, request_id: str):
logger.debug(
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
)
if request_id not in self.req_to_free:
logger.error(
    f"request {request_id} not in req_to_free; this is likely a bug")
return
request = self.req_to_free[request_id]
if not self.req_failed_to_allocate[request_id]:
self.single_type_manager.cache_blocks(
request,
self.req_to_num_tokens[request_id],
)
self._free_slots(request_id)
logger.debug(
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
del self.req_to_free[request_id]
def _free_slots(self, request_id: str):
# This function is idempotent: calling it more than once for the same
# request is safe.
self._release_ahead_touch(request_id)
self.single_type_manager.free(request_id)
self.req_to_block_hashes.pop(request_id, None)
self.req_to_computed_blocks.pop(request_id, None)
self.req_failed_to_allocate.pop(request_id, None)
self.req_to_num_tokens.pop(request_id, None)
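# Illustrative lifecycle sketch (hypothetical driver; `spec` and `request` are
# assumed to come from the scheduler and the block count is an example value):
def _example_request_lifecycle(spec: KVCacheSpec, request: Request) -> None:
    manager = CPUKVCacheManager(kv_cache_spec=spec, num_cpu_blocks=1024)
    # 1) Look up the CPU prefix cache and touch the matched blocks.
    num_computed_tokens, _ = manager.get_matched_num_and_touch(request)
    # 2) Allocate enough CPU blocks to hold the whole prompt.
    req_to_blocks = manager.allocate_slots(
        {request.request_id: request.num_tokens}, unallocated_req_ids=set())
    # 3) Defer caching and freeing until the KV transfer has finished sending.
    manager.record_request_cache_and_free_slots(request)
    manager.cache_and_free_slots(request.request_id)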

View File

@@ -0,0 +1,269 @@
import math
import os
import pickle
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.config import KVTransferConfig, VllmConfig
from vllm.utils import get_dtype_size, logger, make_zmq_socket
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_npu.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
CPUKVCacheManager
@dataclass
class MLAConfig:
nope_dim: int
rope_dim: int
def get_cpu_offload_connector(
        vllm_config: VllmConfig) -> Optional[KVTransferConfig]:
if vllm_config.kv_transfer_config is not None:
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
elif kv_transfer_config.kv_connector == "MultiConnector":
ktcs = kv_transfer_config.kv_connector_extra_config.get(
"connectors")
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
return None
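# Illustrative sketch (hypothetical values): the nested MultiConnector layout
# from which get_cpu_offload_connector() extracts the CPUOffloadingConnector
# entry. The cpu_swap_space_gb value is an example only.
def _example_multi_connector_config() -> KVTransferConfig:
    return KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [{
                "kv_connector": "CPUOffloadingConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "cpu_swap_space_gb": 100
                },
            }]
        },
    )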
class MetadataServer:
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
DEFAULT_CPU_SWAP_SPACE_GB = 800
class ZMQRPCClient:
def __init__(self, identity=f"worker-{os.getpid()}"):
logger.info(f"metadata client for worker {identity} started")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.DEALER, # type: ignore
bind=False,
identity=identity.encode(),
linger=0)
def call(self, func_name: str, *args, **kwargs) -> Any:
request = (func_name, args, kwargs)
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(request))
_ = self.socket.recv()
response = pickle.loads(self.socket.recv())
result, error = response
if error:
logger.exception(f"call metadata sever error: {error}")
raise error
if func_name == "init_cpu_kv_caches":
(memory_dict, layer_size, layer_dtype, mla_config) = result
# Keep the SharedMemory handles on self so they can be closed in __del__.
self.shared_memory_dict = memory_dict
result = {}
for key, shm in memory_dict.items():
tensor = torch.frombuffer(
shm.buf, dtype=layer_dtype).reshape(layer_size)
if mla_config is not None:
tensor = tensor.split(
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
result[key] = tensor
return result
def __del__(self):
# These handles are finalized when the worker process exits; close them
# explicitly here anyway.
self.socket.close()
self.ctx.term()
if hasattr(self, 'shared_memory_dict'):
for shm in self.shared_memory_dict.values():
shm.close()
def __init__(self, vllm_config: VllmConfig):
self.world_size = vllm_config.parallel_config.world_size
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
kv_transfer_config = get_cpu_offload_connector(vllm_config)
assert kv_transfer_config is not None
available_memory_gb = kv_transfer_config.get_from_extra_config(
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
logger.info(f"cpu swap space: {self.available_memory} bytes")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.ROUTER, # type: ignore
bind=True,
linger=0)
self.functions: dict[str, Callable] = {
"init_cpu_kv_caches": self.init_cpu_kv_caches,
"post_init": self.post_init,
"ready": self.ready,
}
self.shared_memory = {} # type: ignore
self.num_cpu_blocks = -1
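# A previous run that crashed may leave a shared-memory segment with the same
# name behind; the helper below unlinks any stale segment before creating a
# fresh one.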
@staticmethod
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
try:
existing_shm = SharedMemory(name=name, create=False)
existing_shm.close()
existing_shm.unlink()
except FileNotFoundError:
pass
return SharedMemory(name=name, create=True, size=size)
def ready(self):
return True
def init_cpu_kv_caches(
self,
pp_rank: int,
tp_rank: int,
kv_cache_specs: dict[str, AttentionSpec],
mla_config: MLAConfig,
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
MLAConfig]:
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
# Assume every layer has the same spec.
layer = next(iter(kv_cache_specs.values()))
assert all(layer.page_size_bytes == spec.page_size_bytes
           for spec in kv_cache_specs.values())
# MLA shares the same KV cache across TP ranks.
if layer.use_mla:
tp_rank = 0
if (pp_rank, tp_rank) in self.shared_memory:
return self.shared_memory[(pp_rank, tp_rank)]
available_memory = self.available_memory
shared_memory_dict = {}
if layer.use_mla:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
else:
available_memory //= self.world_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
for layer_name in kv_cache_specs.keys():
# SharedMemory handles are picklable, so they can be shared with the
# workers over ZeroMQ.
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
if layer.use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, mla_config)
else:
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, None)
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
self.num_cpu_blocks = num_blocks
self.layer = layer
return self.shared_memory[(pp_rank, tp_rank)]
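# Worked sizing example (hypothetical numbers, non-MLA case): with 800 GiB of
# CPU swap space, world_size=8 and 32 attention layers, each layer gets
# 800 GiB / 8 / 32 = 3.125 GiB; at a page_size_bytes of 2 MiB that is 1600 CPU
# blocks per layer. The smallest per-layer block count seen across calls
# becomes self.num_cpu_blocks.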
def post_init(self):
# Different processes under data parallelism may call this multiple times.
if hasattr(self, 'cpu_block_manager'):
return
# init_cpu_kv_caches() must have populated self.shared_memory at least once.
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
assert self.num_cpu_blocks >= 0
self.cpu_block_manager = CPUKVCacheManager(self.layer,
self.num_cpu_blocks)
self.functions.update({
"get_matched_num_and_touch":
self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots":
self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots":
self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots":
self.cpu_block_manager.cache_and_free_slots,
})
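# ZeroMQ framing note: a DEALER client sends [empty delimiter, payload], so the
# ROUTER side receives [client_id, empty delimiter, payload] and must echo the
# client_id and delimiter back with the reply (see serve_step below).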
def serve_step(self):
client_id = self.socket.recv()
_ = self.socket.recv()
raw_msg = self.socket.recv()
try:
func_name, args, kwargs = pickle.loads(raw_msg)
except Exception as e:
response = (None, Exception(f"Invalid request: {str(e)}"))
else:
if func_name in self.functions:
try:
result = self.functions[func_name](*args, **kwargs)
response = (result, None) # type: ignore
except Exception as e:
logger.exception(f"metadata execute error: {e}")
response = (None, e) # type: ignore
else:
response = (None, NameError(f"Function {func_name} not found"))
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(response))
def shutdown(self):
self.socket.close()
self.ctx.term()
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
"ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
for cached in self.shared_memory.values():
for shm in cached[0].values():
shm.close()
shm.unlink()
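# Illustrative sketch (hypothetical worker-side usage; the identity string and
# the RPC arguments are example values): each call() pickles
# (func_name, args, kwargs), sends it to the server's ROUTER socket and blocks
# until the (result, error) reply arrives.
def _example_worker_rpc() -> None:
    client = MetadataServer.ZMQRPCClient(identity="worker-example")
    assert client.call("ready") is True
    # After post_init() has run on the server, block-manager RPCs are exposed:
    client.call("allocate_slots", {}, set())  # no-op example call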
class MetadataServerProc:
@staticmethod
def run_metadata_server(vllm_config: VllmConfig):
if (not vllm_config.cache_config.enable_prefix_caching
or get_cpu_offload_connector(vllm_config) is None):
return
shutdown_requested = False
def _signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the worker
# signal.signal(signal.SIGTERM, _signal_handler)
# signal.signal(signal.SIGINT, _signal_handler)
metadata_server: Optional[MetadataServer] = None
try:
metadata_server = MetadataServer(vllm_config)
logger.info("Metadata server started.")
while True:
metadata_server.serve_step()
except SystemExit:
logger.info("Metadata server exiting.")
raise
except Exception as e:
logger.exception(f"Metadata server error: {e}.")
raise e
finally:
if metadata_server is not None:
metadata_server.shutdown()
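# Illustrative launch sketch (hypothetical; assumes the engine starts the
# metadata server in its own process before the workers connect):
def _example_launch(vllm_config: VllmConfig) -> None:
    import multiprocessing
    proc = multiprocessing.Process(
        target=MetadataServerProc.run_metadata_server,
        args=(vllm_config, ),
        daemon=True,
    )
    proc.start()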