This commit is contained in:
2026-02-10 23:08:39 +08:00
parent 1baa36026c
commit 6680585975
172 changed files with 52867 additions and 892 deletions

View File

@@ -0,0 +1,269 @@
import math
import os
import pickle
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.config import KVTransferConfig, VllmConfig
from vllm.utils import get_dtype_size, logger, make_zmq_socket
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_npu.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
CPUKVCacheManager
@dataclass
class MLAConfig:
nope_dim: int
rope_dim: int
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
if vllm_config.kv_transfer_config is not None:
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
elif kv_transfer_config.kv_connector == "MultiConnector":
ktcs = kv_transfer_config.kv_connector_extra_config.get(
"connectors")
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
return None
class MetadataServer:
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
DEFAULT_CPU_SWAP_SPACE_GB = 800
class ZMQRPCClient:
def __init__(self, identity=f"worker-{os.getpid()}"):
logger.info(f"metadata client for worker {identity} started")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.DEALER, # type: ignore
bind=False,
identity=identity.encode(),
linger=0)
def call(self, func_name: str, *args, **kwargs) -> Any:
request = (func_name, args, kwargs)
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(request))
_ = self.socket.recv()
response = pickle.loads(self.socket.recv())
result, error = response
if error:
logger.exception(f"call metadata sever error: {error}")
raise error
if func_name == "init_cpu_kv_caches":
(memory_dict, layer_size, layer_dtype, mla_config) = result
# shared_memory_dict is recorded in self to close
self.shared_memory_dict = memory_dict
result = {}
for key, shm in memory_dict.items():
tensor = torch.frombuffer(
shm.buf, dtype=layer_dtype).reshape(layer_size)
if mla_config is not None:
tensor = tensor.split(
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
result[key] = tensor
return result
def __del__(self):
# will be finalized by outer process
self.socket.close()
self.ctx.term()
if hasattr(self, 'shared_memory_dict'):
for shm in self.shared_memory_dict.values():
shm.close()
def __init__(self, vllm_config: VllmConfig):
self.world_size = vllm_config.parallel_config.world_size
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
kv_transfer_config = get_cpu_offload_connector(vllm_config)
assert kv_transfer_config is not None
available_memory_gb = kv_transfer_config.get_from_extra_config(
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
logger.info(f"cpu swap space: {self.available_memory} bytes")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.ROUTER, # type: ignore
bind=True,
linger=0)
self.functions: dict[str, Callable] = {
"init_cpu_kv_caches": self.init_cpu_kv_caches,
"post_init": self.post_init,
"ready": self.ready,
}
self.shared_memory = {} # type: ignore
self.num_cpu_blocks = -1
@staticmethod
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
try:
existing_shm = SharedMemory(name=name, create=False)
existing_shm.close()
existing_shm.unlink()
except FileNotFoundError:
pass
return SharedMemory(name=name, create=True, size=size)
def ready(self):
return True
def init_cpu_kv_caches(
self,
pp_rank: int,
tp_rank: int,
kv_cache_specs: dict[str, AttentionSpec],
mla_config: MLAConfig,
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
MLAConfig]:
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
# follow the assumption that each layer has the same spec
layer = next(iter(kv_cache_specs.values()))
assert all([
layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values()
])
# mla shares the same kv cache among different tp
if layer.use_mla:
tp_rank = 0
if (pp_rank, tp_rank) in self.shared_memory:
return self.shared_memory[(pp_rank, tp_rank)]
available_memory = self.available_memory
shared_memory_dict = {}
if layer.use_mla:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
else:
available_memory //= self.world_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
for layer_name in kv_cache_specs.keys():
# only this format can share during ZeroMQ+pickle
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
if layer.use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, mla_config)
else:
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, None)
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
self.num_cpu_blocks = num_blocks
self.layer = layer
return self.shared_memory[(pp_rank, tp_rank)]
def post_init(self):
# different processors in data parallel may call multiple times
if hasattr(self, 'cpu_block_manager'):
return
# do shared_memory() at least once
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
assert self.num_cpu_blocks >= 0
self.cpu_block_manager = CPUKVCacheManager(self.layer,
self.num_cpu_blocks)
self.functions.update({
"get_matched_num_and_touch":
self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots":
self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots":
self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots":
self.cpu_block_manager.cache_and_free_slots,
})
def serve_step(self):
client_id = self.socket.recv()
_ = self.socket.recv()
raw_msg = self.socket.recv()
try:
func_name, args, kwargs = pickle.loads(raw_msg)
except Exception as e:
response = (None, Exception(f"Invalid request: {str(e)}"))
else:
if func_name in self.functions:
try:
result = self.functions[func_name](*args, **kwargs)
response = (result, None) # type: ignore
except Exception as e:
logger.exception(f"metadata execute error: {e}")
response = (None, e) # type: ignore
else:
response = (None, NameError(f"Function {func_name} not found"))
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(response))
def shutdown(self):
self.socket.close()
self.ctx.term()
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
"ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
for cached in self.shared_memory.values():
for shm in cached[0].values():
shm.close()
shm.unlink()
class MetadataServerProc:
@staticmethod
def run_metadata_server(vllm_config: VllmConfig):
if (not vllm_config.cache_config.enable_prefix_caching
or get_cpu_offload_connector(vllm_config) is None):
return
shutdown_requested = False
def _signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the worker
# signal.signal(signal.SIGTERM, _signal_handler)
# signal.signal(signal.SIGINT, _signal_handler)
metadata_server: Optional[MetadataServer] = None
try:
metadata_server = MetadataServer(vllm_config)
logger.info("Metadata server started.")
while True:
metadata_server.serve_step()
except SystemExit:
logger.info("Metadata server exiting.")
raise
except Exception as e:
logger.exception(f"Metadata server error: {e}.")
raise e
finally:
if metadata_server is not None:
metadata_server.shutdown()