Files
vllm-npu-plugin/vllm_npu/distributed/mooncake_connector.py
2026-02-10 23:08:39 +08:00

1264 lines
54 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import contextlib
import hashlib
import math
import os
import queue
import random
import struct
import threading
import time
from collections import defaultdict, deque
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
import msgspec
import numpy as np
import numpy.typing as npt
import torch
import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
get_tp_group)
from vllm.utils import get_ip, logger, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_npu.envs as envs_ascend
from vllm_npu.ascend_config import get_ascend_config, init_ascend_config
from vllm_npu.distributed.mooncake.transfer_engine import get_global_te
from vllm_npu.distributed.utils import get_transfer_timeout_value
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
# Message tags exchanged over the ZMQ side channel between the receiving and
# the sending worker.
# GET_META_MSG: the sending thread replies with its msgpack-encoded
# MooncakeAgentMetadata.
GET_META_MSG = b"get_meta_msg"
# DONE_RECVING_MSG: carries a request id; the sending thread records the
# request as done in its task tracker and replies b"ACK".
DONE_RECVING_MSG = b"done_recving_msg"
class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
    """Handshake payload sent in reply to ``GET_META_MSG``.

    The receiving thread decodes it and caches ``kv_caches_base_addr`` and
    ``te_rpc_port`` per (engine id, handshake port).
    """
    # Engine id of the worker that published this metadata; the receiver
    # asserts it differs from its own engine id.
    engine_id: str
    # Port of the publisher's transfer engine; the receiver combines it with
    # the remote host to form the transfer session id.
    te_rpc_port: int
    # Base addresses of the publisher's KV cache buffers; the receiver zips
    # them with its own base addresses when building a transfer.
    kv_caches_base_addr: list[int]
    # presumably the number of KV cache blocks on the publisher -- not read in
    # the visible code, TODO confirm.
    num_blocks: int
@dataclass
class ReqMeta:
    """Per-request transfer description built from ``kv_transfer_params``."""
    # Block ids allocated locally for the request.
    local_block_ids: list[int]
    # Block ids of the request on the remote side
    # (``kv_transfer_params["remote_block_ids"]``).
    remote_block_ids: list[int]
    # Side-channel host and base port advertised by the remote engine.
    remote_host: str
    remote_port: int
    # Engine id of the remote side.
    remote_engine_id: str
class KVCacheTaskTracker:
    """Lock-protected bookkeeping of finished and delayed-free requests.

    All public methods take ``done_task_lock``; the underscore-prefixed
    helpers expect the caller to already hold it.
    """

    def __init__(self):
        super().__init__()
        self.done_task_lock = threading.Lock()
        # Requests completed since the last get_and_clear_finished_requests().
        self.finished_requests: set[str] = set()
        # Requests reported done before a delayed-free entry was registered
        # for them; consulted (and cleared) by add_delayed_request().
        self.record_finished_requests: set[str] = set()
        # Only used in prefill node. Tracks requests whose kv blocks freeing
        # is intentionally delayed, mapping request_id -> timestamp at which
        # the delay started. If a request remains here for too long, it will
        # be force-freed.
        self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()

    def add_not_transfer_request(self, request_id: str):
        """Mark a request that needs no transfer as finished."""
        with self.done_task_lock:
            self.finished_requests.add(request_id)

    def update_done_task_count(self, request_id: str):
        """Mark a request as finished and settle its delayed-free state."""
        with self.done_task_lock:
            self.finished_requests.add(request_id)
            if request_id not in self.delayed_free_requests:
                # No delayed-free entry yet: remember the completion so a
                # later add_delayed_request() for this id becomes a no-op.
                self.record_finished_requests.add(request_id)
                return
            self._remove_delayed_requests(request_id)

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.
        Returns:
            A set of request IDs that have been completed, including
            delayed-free requests that exceeded the timeout.
        """
        with self.done_task_lock:
            completed = set(self.finished_requests)
            completed |= self._retrieve_expired_requests()
            self.finished_requests.clear()
            return completed

    def add_delayed_request(self, request_id: str, delay_start_time: float):
        """Add a delayed free request."""
        with self.done_task_lock:
            if request_id in self.record_finished_requests:
                # Completion was already reported; nothing left to delay.
                self.record_finished_requests.discard(request_id)
            else:
                self.delayed_free_requests[request_id] = delay_start_time

    def _retrieve_expired_requests(self):
        """Retrieve all expired delayed requests."""
        force_freed: set[str] = set()
        # Free delayed requests if they exceed the timeout
        now = time.time()
        pending = self.delayed_free_requests
        while pending:
            # Entries are in insertion order, so the first one is the oldest.
            oldest_id = next(iter(pending))
            waited = now - pending[oldest_id]
            if not (waited > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
                break
            pending.popitem(last=False)
            force_freed.add(oldest_id)
            logger.info("Force freed request: %s", oldest_id)
        return force_freed

    def _remove_delayed_requests(self, request_id: str):
        """Remove the delayed free request matching the given request_id."""
        self.delayed_free_requests.pop(request_id)
class KVCacheSendingThread(threading.Thread):
    """Daemon thread answering side-channel messages on the sending side.

    It listens on a ZMQ ROUTER socket at ``side_channel_port + tp_rank`` and
    serves two message kinds: ``GET_META_MSG`` (reply with the encoded
    ``MooncakeAgentMetadata``) and ``DONE_RECVING_MSG`` (record the request
    as done in the task tracker and reply ``ACK``).
    """

    def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
                 side_channel_host: str, side_channel_port: int,
                 metadata: MooncakeAgentMetadata, ready_event: threading.Event,
                 kv_caches: dict[str, Any]):
        super().__init__(daemon=True, name="KVCacheSendingThread")
        self.tp_rank = tp_rank
        self.decode_tp_size = decode_tp_size
        self.local_engine_id = local_engine_id
        self.side_channel_host = side_channel_host
        self.side_channel_port = side_channel_port
        self.metadata = metadata
        # Set once the listening socket is available.
        self.ready_event = ready_event
        self.kv_caches = kv_caches
        self.task_tracker = KVCacheTaskTracker()

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.
        Returns:
            A set of request IDs that have been completed.
        """
        return self.task_tracker.get_and_clear_finished_requests()

    def add_not_transfer_request(self, request_id: str):
        """Forward a no-transfer request to the task tracker."""
        self.task_tracker.add_not_transfer_request(request_id)

    def add_delayed_request(self, request_id: str, delay_start_time: float):
        """Forward a delayed-free registration to the task tracker."""
        return self.task_tracker.add_delayed_request(request_id,
                                                     delay_start_time)

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        # The metadata never changes, so it is encoded once up front.
        meta_blob = msgspec.msgpack.Encoder().encode(self.metadata)
        logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes",
                     str(len(meta_blob)))
        # Listen for new requests for metadata.
        # NOTE(rob): we need each rank to have a unique port. This hack keeps
        # us moving. We will switch when moving to etcd or where we have a
        # single ZMQ socket in the scheduler.
        handshake_port = self.side_channel_port + self.tp_rank
        path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
        logger.info("Starting listening on path: %s", path)
        with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
            self.ready_event.set()
            decoder = msgspec.msgpack.Decoder(type=tuple)
            while True:
                try:
                    self._serve_one_message(sock, decoder, meta_blob)
                except Exception as e:
                    # Keep the listener alive no matter what one message did.
                    logger.error("Connection listener got exception %s: %s",
                                 type(e), e)

    def _serve_one_message(self, sock, decoder, meta_blob: bytes) -> None:
        """Receive one multipart message from ``sock`` and answer it."""
        frames = sock.recv_multipart()
        if len(frames) < 2:
            logger.error("Invalid message format: %s", frames)
            return
        identity = frames[0]
        # Drop empty delimiter frames; exactly one payload frame must remain.
        payload = [frame for frame in frames[1:] if frame != b""]
        if len(payload) != 1:
            logger.error("Invalid message format: %s", frames)
            return
        msg = decoder.decode(payload[0])
        kind = msg[0]
        if kind == GET_META_MSG:
            sock.send_multipart((identity, b"", meta_blob))
            return
        if kind == DONE_RECVING_MSG:
            logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
            request_id = msg[1]
            self.task_tracker.update_done_task_count(request_id)
            # Acknowledge the request completion.
            self._send_ack(sock, identity, request_id)
            return
        logger.error("Connection listener got unexpected message %s", msg)

    @staticmethod
    def _send_ack(sock, identity: bytes, request_id) -> None:
        """Send ``ACK`` to ``identity``, retrying while the socket is busy."""
        while True:
            try:
                # Send ACK to the sender.
                sock.send_multipart(
                    (identity, b"", b"ACK"),
                    flags=zmq.NOBLOCK)  # type: ignore
                return
            except zmq.Again:  # type: ignore
                # If the socket is not ready, retry sending.
                logger.debug(
                    "Socket not ready, retrying to send ACK for "
                    "request %s", request_id)
                time.sleep(0.01)
class KVCacheRecvingThread(threading.Thread):
    """Daemon thread that pulls KV cache blocks from remote engines.

    Requests are queued through ``add_request`` and processed one at a time
    by ``run``. For each request the thread reads the remote blocks through
    the transfer engine, notifies the remote side over a ZMQ REQ socket that
    receiving is done, and records completion in its task tracker.
    """

    def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine,
                 local_engine_id: str, local_handshake_port: int,
                 local_kv_caches_base_addr: list[int], block_len: list[int],
                 ready_event: threading.Event, vllm_config: VllmConfig,
                 kv_caches: dict[str, Any]):
        super().__init__(daemon=True, name="KVCacheRecvingThread")
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.local_engine_id = local_engine_id
        self.local_handshake_port = local_handshake_port
        self.engine = engine
        self.ready_event = ready_event
        # engine id -> handshake port -> KV cache base addresses. Seeded with
        # the local entry; remote entries are added by _get_remote_metadata.
        self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
            defaultdict(dict)
        self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \
            local_kv_caches_base_addr
        # engine id -> handshake port -> transfer engine RPC port.
        self.remote_te_port: dict[str, dict[int, int]] = \
            defaultdict(dict)
        self.block_len = block_len
        # TODO(jianzs): find a better way to detect MLA.
        self.use_mla = len(block_len) == 2
        self.use_sparse = len(block_len) == 3
        self.request_queue: queue.Queue[Any] = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=32)
        self.task_tracker = KVCacheTaskTracker()
        self.encoder = msgspec.msgpack.Encoder()
        self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
        # Pool of idle REQ sockets keyed by remote ZMQ path.
        self.remote_sockets_lock = threading.Lock()
        self.remote_sockets: dict[  # type: ignore
            str, deque[zmq.Socket]] = defaultdict(  # type: ignore
                deque)
        self.remote_poller = zmq.Poller()  # type: ignore
        self.timeout = 1.0  # seconds
        self.vllm_config = vllm_config
        self.model_config = self.vllm_config.model_config
        self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads
        self.kv_caches = kv_caches

    def add_request(self, request_id: str, local_block_ids: list[int],
                    remote_block_ids: list[int], remote_engine_id: str,
                    remote_host: str, remote_handshake_port: int, offset: int,
                    num_need_pulls: int):
        """Add a new request to the queue for processing."""
        logger.debug(f"Adding request {request_id} to the queue.")
        self.request_queue.put({
            "request_id": request_id,
            "local_block_ids": local_block_ids,
            "remote_block_ids": remote_block_ids,
            "remote_engine_id": remote_engine_id,
            "remote_host": remote_host,
            "remote_handshake_port": remote_handshake_port,
            "offset": offset,
            "num_need_pulls": num_need_pulls
        })

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.
        Returns:
            A set of request IDs that have been completed.
        """
        return self.task_tracker.get_and_clear_finished_requests()

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        self.ready_event.set()
        while True:
            try:
                request_data = self.request_queue.get()
                if request_data is None:
                    logger.warning("Received a None request!")
                    self.request_queue.task_done()
                    continue
                self._handle_request(request_data)
            except Exception as e:
                logger.error(f"Error in KVCacheTransferThread: {e}")

    def _handle_request(self, req_meta: dict[str, Any]):
        """Process one queued request: transfer, notify the remote, finalize.

        Transfer and notification failures are logged rather than raised, so
        the local bookkeeping at the end always runs.
        """
        request_id = req_meta["request_id"]
        remote_host = req_meta["remote_host"]
        remote_handshake_port = req_meta["remote_handshake_port"]
        offset = req_meta["offset"]
        num_need_pulls = req_meta["num_need_pulls"]
        try:
            logger.debug(
                f"Starting to transfer KV cache for request {request_id}.")
            self._transfer_kv_cache(req_meta)
            logger.debug(
                f"Finished transferring KV cache for request {request_id}.")
        except Exception as e:
            logger.error("Failed to transfer KV cache for request "
                         f"{request_id}: {e}")
        finally:
            try:
                # Always send the done signal to the remote host to ensure
                # proper resource cleanup. Failing to do so may cause a memory
                # leak on the remote host.
                self._send_done_recv_signal(request_id, remote_host,
                                            remote_handshake_port)
            except Exception as e:
                # A failed notification must not skip the bookkeeping below:
                # otherwise the request is never reported as finished locally
                # and the queue's unfinished-task count is never decremented.
                logger.error(
                    "Failed to send done recving signal for request %s to "
                    "%s:%d: %s", request_id, remote_host,
                    remote_handshake_port, e)
            finally:
                # Only the last pull of a request marks it as finished.
                if offset == num_need_pulls - 1:
                    self.task_tracker.update_done_task_count(request_id)
                self.request_queue.task_done()

    def _transfer_kv_cache(self, req_meta: dict[str, Any]):
        """Handle a KV cache transfer request.

        Builds one (local address, remote address, length) triple per layer
        buffer and block group, then issues a single batched synchronous read.

        Raises:
            RuntimeError: if the transfer engine reports a negative status.
        """
        request_id = req_meta["request_id"]
        remote_block_ids = req_meta["remote_block_ids"]
        local_block_ids = req_meta["local_block_ids"]
        remote_engine_id = req_meta["remote_engine_id"]
        remote_host = req_meta["remote_host"]
        remote_handshake_port = req_meta["remote_handshake_port"]
        offset = req_meta["offset"]
        # Kept on the instance because _transpose_kv_cache_between_head
        # reads it.
        self.num_need_pulls = req_meta["num_need_pulls"]
        # Full prefix cache hit: do not need to read remote blocks, just notify
        # P worker that we have the blocks we need.
        num_local_blocks = len(local_block_ids)
        if num_local_blocks == 0:
            return
        num_remote_blocks = len(remote_block_ids)
        assert num_local_blocks <= num_remote_blocks
        if num_local_blocks < num_remote_blocks:
            # Only the trailing remote blocks are needed locally.
            remote_block_ids = remote_block_ids[-num_local_blocks:]
        # Check if we have the remote metadata cached.
        if remote_engine_id not in self.kv_caches_base_addr or \
                remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
            self._get_remote_metadata(remote_host, remote_handshake_port)
        if self.num_need_pulls == 1:
            grouped_remote_block_ids, grouped_local_block_ids = \
                group_concurrent_contiguous(remote_block_ids, local_block_ids)
        else:
            # With several pulls per block every block is its own group.
            grouped_remote_block_ids = [[block_id]
                                        for block_id in remote_block_ids]
            grouped_local_block_ids = [[block_id]
                                       for block_id in local_block_ids]
        num_transfer_groups = len(grouped_remote_block_ids)
        remote_kv_caches_base_addrs = \
            self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
        local_kv_caches_base_addrs = \
            self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port]
        remote_transfer_port = self.remote_te_port[remote_engine_id][
            remote_handshake_port]
        num_blocks = num_local_blocks
        session_id = f"{remote_host}:{remote_transfer_port}"
        req_start_time = time.perf_counter()
        # NOTE: for this read, "src" holds the local addresses and "dst" the
        # remote addresses.
        src_list, dst_list, length_list = [], [], []
        for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
            # Buffers cycle through the per-kind block lengths.
            if self.use_mla:
                block_len = (self.block_len[k % 2])
            elif self.use_sparse:
                block_len = (self.block_len[k % 3])
            else:
                block_len = (self.block_len[0])
            # Bytes fetched per block by a single pull.
            inner_block_len = block_len // self.num_need_pulls
            for remote_block_id, local_block_id in zip(
                    grouped_remote_block_ids, grouped_local_block_ids):
                src = src_layer_base_addr + local_block_id[
                    0] * block_len + offset * inner_block_len
                dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
                length = inner_block_len * len(local_block_id)
                src_list.append(src)
                dst_list.append(dst)
                length_list.append(length)
        ret = self.engine.batch_transfer_sync_read(session_id, src_list,
                                                   dst_list, length_list)
        if ret < 0:
            logger.error("Mooncake transfer failed for request %s",
                         req_meta["request_id"])
            raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
        req_end_time = time.perf_counter()
        req_transfer_elapsed = (req_end_time - req_start_time) * 1000
        logger.info(
            "KV cache transfer for request %s took %.2f ms (%d groups,"
            " %d blocks). local_ip %s local_device_id %s remote_session_id %s",
            request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
            get_ip(), self.tp_rank, session_id)
        # After the last of several pulls, rearrange the pulled data in place.
        if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1:
            self._cat_kv_cache(grouped_local_block_ids)

    def _cat_kv_cache(self, block_ids: list[list[int]]):
        """Rearrange the given local blocks in place after a multi-pull.

        For every layer the blocks are loaded into contiguous buffers,
        transposed by ``_transpose_kv_cache_between_head`` and written back
        to the same slots.
        """
        # Get necessary parameters
        k_cache = list(self.kv_caches.values())[0][0]
        kv_shape = k_cache.shape
        dtype = k_cache.dtype
        device = k_cache.device
        head_dim = self.model_config.hf_config.head_dim
        block_size = self.vllm_config.cache_config.block_size
        num_kv_head = max(
            self.model_config.hf_config.num_key_value_heads // self.tp_size, 1)
        flat_block_ids = [item for sublist in block_ids for item in sublist]
        block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
        num_blocks = len(flat_block_ids)
        # Total number of token slots covered by the blocks.
        block_len = num_blocks * block_size
        # Create device tensors for copy operations
        block_table = block_ids_tensor.view(1, -1).to(device=device)
        block_len_tensor = torch.tensor([block_len],
                                        dtype=torch.int32).to(device=device)
        seq_start_tensor = torch.tensor([0],
                                        dtype=torch.int32).to(device=device)
        # Initialize buffers
        k_buffer = torch.empty(block_len,
                               num_kv_head,
                               head_dim,
                               dtype=dtype,
                               device=device)
        v_buffer = torch.empty(block_len,
                               num_kv_head,
                               head_dim,
                               dtype=dtype,
                               device=device)
        # Create slot mapping for reshape operations
        block_offsets = torch.arange(0, block_size, dtype=torch.int32)
        slot_mapping = (block_offsets.reshape(
            (1, block_size)) + block_ids_tensor.reshape(
                (num_blocks, 1)) * block_size)
        slot_mapping = slot_mapping.flatten().to(device=device)
        # Process each layer in the KV cache
        for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
            if len(
                    k_cache_layer.shape
            ) == 3:  # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim]
                k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1],
                                                   num_kv_head, head_dim)
                v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1],
                                                   num_kv_head, head_dim)
            # Load cache data into buffers
            torch_npu.atb.npu_paged_cache_load(
                k_cache_layer,
                v_cache_layer,
                block_table,
                block_len_tensor,
                seq_starts=seq_start_tensor,
                key=k_buffer,
                value=v_buffer,
            )
            # Transpose KV cache
            k_buffer = self._transpose_kv_cache_between_head(
                k_buffer, num_blocks, block_size, block_len, num_kv_head)
            v_buffer = self._transpose_kv_cache_between_head(
                v_buffer, num_blocks, block_size, block_len, num_kv_head)
            # Reshape and cache the processed buffers
            torch_npu._npu_reshape_and_cache(
                key=k_buffer,
                value=v_buffer,
                key_cache=k_cache_layer,
                value_cache=v_cache_layer,
                slot_indices=slot_mapping,
            )
        # Clean up buffers
        del k_buffer, v_buffer

    def _transpose_kv_cache_between_head(self, buffer: torch.Tensor,
                                         num_blocks: int, block_size: int,
                                         block_len: int,
                                         num_kv_head: int) -> torch.Tensor:
        """Swap the pull axis and the token axis of ``buffer``.

        The buffer is viewed as [num_blocks, num_need_pulls, block_size, -1],
        axes 1 and 2 are swapped, and the result is returned contiguously as
        [block_len, num_kv_head, -1].
        """
        buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1)
        buffer.transpose_(1, 2)
        return buffer.contiguous().view(block_len, num_kv_head, -1)

    def _get_remote_metadata(self, remote_host: str,
                             remote_handshake_port: int) -> None:
        """Get the metadata from the remote host."""
        sock: Optional[zmq.Socket] = None  # type: ignore
        # Becomes True once the request/reply cycle on the socket completed.
        reusable = False
        try:
            sock = self._get_remote_socket(remote_host, remote_handshake_port)
            ensure_zmq_send(sock, self.encoder.encode((GET_META_MSG, "")))
            metadata_bytes = ensure_zmq_recv(sock, self.remote_poller)
            reusable = True
            agent_meta = self.decoder.decode(metadata_bytes)
            engine_id = agent_meta.engine_id
            assert engine_id != self.local_engine_id, (
                f"Conflict engine id {engine_id} with local engine id "
                f"{self.local_engine_id}.")
            self.kv_caches_base_addr[engine_id][remote_handshake_port] = \
                agent_meta.kv_caches_base_addr
            self.remote_te_port[engine_id][remote_handshake_port] = \
                agent_meta.te_rpc_port
        finally:
            if sock is not None:
                self._release_remote_socket(sock, remote_host,
                                            remote_handshake_port, reusable)

    def _send_done_recv_signal(self, request_id: str, remote_host: str,
                               remote_handshake_port: int):
        """Tell the remote host that receiving for ``request_id`` is done.

        Raises:
            RuntimeError: if the remote reply is not ``ACK``.
        """
        logger.debug("Sending done recving signal for request %s to %s:%d",
                     request_id, remote_host, remote_handshake_port)
        sock: Optional[zmq.Socket] = None  # type: ignore
        # Becomes True once the request/reply cycle on the socket completed.
        reusable = False
        try:
            sock = self._get_remote_socket(remote_host, remote_handshake_port)
            data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id))
            ensure_zmq_send(sock, data_bytes)
            resp = ensure_zmq_recv(sock,
                                   self.remote_poller,
                                   timeout=self.timeout)
            # A reply arrived, so the socket can send again even if the
            # reply turns out not to be an ACK.
            reusable = True
            logger.debug(
                f"Received response for request {request_id}: {resp.decode('utf-8')}"
            )
            if resp != b"ACK":
                logger.error("Failed to receive ACK for request %s from %s:%d",
                             request_id, remote_host, remote_handshake_port)
                raise RuntimeError(
                    f"Failed to receive ACK, resp: {resp.decode('utf-8')}")
        finally:
            if sock is not None:
                self._release_remote_socket(sock, remote_host,
                                            remote_handshake_port, reusable)

    def _get_remote_socket(
            self, remote_host: str,
            remote_handshake_port: int) -> zmq.Socket:  # type: ignore
        """Get a socket to the remote host."""
        remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
        with self.remote_sockets_lock:
            if self.remote_sockets[remote_path]:
                return self.remote_sockets[remote_path].popleft()
            ctx = zmq.Context()  # type: ignore
            sock = make_zmq_socket(
                ctx=ctx,
                path=remote_path,
                socket_type=zmq.REQ,  # type: ignore
                bind=False)
            sock.setsockopt(
                zmq.SNDTIMEO,  # type: ignore
                int(self.timeout * 1000))
            self.remote_poller.register(sock, zmq.POLLIN)  # type: ignore
            return sock

    def _return_remote_socket(
            self,
            sock: zmq.Socket,  # type: ignore
            remote_host: str,
            remote_handshake_port: int) -> None:
        """Return the remote socket to the pool."""
        remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
        with self.remote_sockets_lock:
            self.remote_sockets[remote_path].append(sock)

    def _release_remote_socket(
            self,
            sock: zmq.Socket,  # type: ignore
            remote_host: str,
            remote_handshake_port: int,
            reusable: bool) -> None:
        """Pool a healthy socket, or unregister and close a broken one.

        A REQ socket whose request was never answered cannot send again, and
        a late reply on it would wake the shared poller for an unrelated
        request, so such a socket must not go back to the pool.
        """
        if reusable:
            self._return_remote_socket(sock, remote_host,
                                       remote_handshake_port)
            logger.debug("Returned socket to pool for %s:%d", remote_host,
                         remote_handshake_port)
            return
        try:
            self.remote_poller.unregister(sock)
        except KeyError:
            # The socket was not registered; nothing to undo.
            pass
        # linger=0 drops any unsent message instead of blocking on close.
        sock.close(linger=0)
        logger.debug("Closed unusable socket for %s:%d", remote_host,
                     remote_handshake_port)
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Metadata handed from the scheduler-side connector to the worker."""

    def __init__(self):
        # request id -> transfer description for requests that must receive.
        self.requests: dict[str, ReqMeta] = {}
        # request id -> timestamp for requests whose blocks are to be sent.
        self.requests_to_send: dict[str, float] = {}

    def add_new_req(
        self,
        request_id: str,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
    ):
        """Record a request to receive, reading the remote fields from
        ``kv_transfer_params``."""
        params = kv_transfer_params
        req_meta = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=params["remote_block_ids"],
            remote_engine_id=params["remote_engine_id"],
            remote_host=params["remote_host"],
            remote_port=params["remote_port"],
        )
        self.requests[request_id] = req_meta
class MooncakeConnector(KVConnectorBase_V1):
    """KV connector facade that delegates to a scheduler-side or a
    worker-side implementation, depending on the role it was created with."""

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id
        engine_id = str(self.engine_id)
        # Exactly one of the two delegates is populated per role.
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
                MooncakeConnectorScheduler(vllm_config, engine_id)
            self.connector_worker: Optional[MooncakeConnectorWorker] = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = MooncakeConnectorWorker(
                vllm_config, engine_id)

    ############################################################
    # Scheduler Side Methods
    ############################################################
    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request,
                                                    num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(request, blocks,
                                                  num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Delegate KV cache registration to the worker-side implementation."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
        # ``finished_req_ids`` is not forwarded to the worker.
        worker = self.connector_worker
        assert worker is not None
        return worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Start loading KV caches described by the bound connector metadata."""
        worker = self.connector_worker
        assert worker is not None
        assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
        worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """MooncakeConnector does not do layerwise saving."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """MooncakeConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """MooncakeConnector does not save explicitly."""
        pass
class MooncakeConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        init_ascend_config(vllm_config)
        self.ascend_config = get_ascend_config()
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        self.local_ip = get_ip()
        logger.info("Initializing Mooncake Scheduler %s", engine_id)
        self.side_channel_host = get_ip()
        parallel_config = vllm_config.parallel_config
        self.max_device_id = parallel_config.tensor_parallel_size * \
            parallel_config.data_parallel_size
        # Handshake base port
        self.side_channel_port = (
            vllm_config.kv_transfer_config.kv_port +
            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size)
        # Requests that need to start recv.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
        # request id -> time at which request_finished delayed its free.
        self._reqs_need_send: dict[str, float] = {}

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.
        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens, params)
        if params is None or not params.get("do_remote_prefill"):
            # No remote prefill for this request.
            return 0, False
        # Remote prefill: get all prompt blocks from remote.
        assert num_computed_tokens % self.block_size == 0
        # Note: We use the full token count as transmit data here.
        count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
        return count, count > 0

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Queue the request for receiving once its blocks are allocated.

        Clears ``do_remote_prefill`` afterwards so the transfer is triggered
        only once per request.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector update_state_after_alloc: "
            "num_external_tokens=%s, kv_transfer_params=%s",
            num_external_tokens, params)
        if params is None or not params.get("do_remote_prefill"):
            return
        required_keys = ("remote_engine_id", "remote_host", "remote_port")
        if not params.get("remote_block_ids"):
            assert num_external_tokens == 0
        elif all(p in params for p in required_keys):
            # Get unhashed blocks to pull from remote.
            if num_external_tokens > 0:
                local_block_ids = blocks.get_unhashed_block_ids()
            else:
                local_block_ids = []
            self._reqs_need_recv[request.request_id] = (request,
                                                        local_block_ids)
        else:
            logger.warning(
                "Got invalid KVTransferParams: %s. This "
                "request will not utilize KVTransfer", params)
        # Only trigger 1 KV transfer per request.
        params["do_remote_prefill"] = False

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Drain the pending recv/send requests into a fresh metadata object."""
        meta = MooncakeConnectorMetadata()
        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, pending in self._reqs_need_recv.items():
            req, block_ids = pending
            assert req.kv_transfer_params is not None
            # Requests with no remote blocks to pull (block_ids is empty)
            # are still added to the metadata here.
            meta.add_new_req(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
            )
        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        # Hand the accumulated send map over and start a new one.
        meta.requests_to_send, self._reqs_need_send = self._reqs_need_send, {}
        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector request_finished, request_status=%s, "
            "kv_transfer_params=%s", request.status, params)
        if params is None or not params.get("do_remote_decode"):
            return False, None
        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            return False, None
        computed_block_ids = block_ids
        delay_free_blocks = len(computed_block_ids) > 0
        if delay_free_blocks:
            logger.info("Delaying free of %d blocks for request %s",
                        len(computed_block_ids), request.request_id)
            self._reqs_need_send[request.request_id] = time.time()
        # Transfer parameters returned alongside the delay flag.
        remote_params = {
            "do_remote_prefill": True,
            "do_remote_decode": False,
            "remote_block_ids": computed_block_ids,
            "remote_engine_id": self.engine_id,
            "remote_host": self.side_channel_host,
            "remote_port": self.side_channel_port,
            "last_token_id": request.output_token_ids[-1],
        }
        return delay_free_blocks, remote_params
class MooncakeConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
    """Set up worker-side state: parallel layout, handshake ports, device
    id, the transfer engine and the number of pulls needed per block.

    Raises:
        ValueError: if the prefill TP size is smaller than the decode TP
            size, or PHYSICAL_DEVICES lists too few devices for this DP rank.
    """
    # Populates _prefill_tp_size/_prefill_dp_size/_decode_tp_size/
    # _decode_dp_size from the kv_transfer extra config.
    self._get_prefill_decode_size(vllm_config)
    # Exported before the transfer engine is obtained further down.
    os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(
        get_transfer_timeout_value())
    if self._prefill_tp_size < self._decode_tp_size:
        raise ValueError(
            f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
            f" or equal to the decode_tp_size: {self._decode_tp_size}")
    # Metadata.
    self.vllm_config = vllm_config
    self.ascend_config = get_ascend_config()
    self.engine_id = engine_id
    self.tp_rank = get_tensor_model_parallel_rank()
    self.tp_size = vllm_config.parallel_config.tensor_parallel_size
    self.tp_group = get_tp_group()
    self.dp_rank = vllm_config.parallel_config.data_parallel_rank
    self.dp_size = vllm_config.parallel_config.data_parallel_size_local
    self.kv_caches: dict[str, torch.Tensor] = {}
    self.side_channel_host = get_ip()
    self.max_device_id = self.tp_size * self.dp_size
    self.kv_role = vllm_config.kv_transfer_config.kv_role
    self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
    # Handshake base port
    self.side_channel_port = (
        vllm_config.kv_transfer_config.kv_port +
        vllm_config.parallel_config.data_parallel_rank *
        vllm_config.parallel_config.tensor_parallel_size)
    # Each TP rank gets its own handshake port.
    self.handshake_port = self.side_channel_port + self.tp_rank
    self.sockets: dict = {}
    # get tp device id
    # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
    # introducing some changes
    device_ids_str = envs_ascend.PHYSICAL_DEVICES
    if device_ids_str is None:
        # No explicit list: use the contiguous range owned by this DP rank.
        device_ids = list(
            range(self.dp_rank * self.tp_size,
                  (self.dp_rank + 1) * self.tp_size))
    else:
        # Comma-separated list: take this DP rank's slice of tp_size entries.
        device_ids = list(map(int, device_ids_str.split(',')))
        start_index = self.dp_rank * self.tp_size
        end_index = start_index + self.tp_size
        if len(device_ids) < end_index:
            raise ValueError(
                f"Not enough physical devices available for DP rank {self.dp_rank}. "
                f"Expected at least {end_index} devices, but found {len(device_ids)} "
                "in PHYSICAL_DEVICES.")
        device_ids = device_ids[start_index:end_index]
    assert len(device_ids) > self.tp_rank  # type: ignore
    self.device_id = device_ids[self.tp_rank]  # type: ignore
    # The engine hostname is the bare host when 'use_ascend_direct' is set
    # (default True); otherwise it also carries the NPU device id.
    if vllm_config.kv_transfer_config.get_from_extra_config(
            'use_ascend_direct', True):
        hostname = self.side_channel_host
    else:
        hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
    logger.info("Initializing Mooncake work %s", engine_id)
    self.engine = get_global_te(hostname, device_name=None)
    self.te_rpc_port = self.engine.get_rpc_port()
    # Background thread for sending or receiving KV caches.
    self.kv_send_thread: Optional[KVCacheSendingThread] = None
    self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
    # kv_transfer variables
    self.vllm_config = vllm_config
    self.block_size = vllm_config.cache_config.block_size
    if self.vllm_config.model_config.is_deepseek_mla:
        self.num_need_pulls = 1
    else:
        # Ratio of KV heads per rank on the local TP layout to KV heads per
        # rank on the prefill TP layout (each at least 1).
        num_d_block_heads = max(1,
                                self.num_key_value_heads // self.tp_size)
        num_p_block_heads = max(
            1, self.num_key_value_heads // self._prefill_tp_size)
        self.num_need_pulls = num_d_block_heads // num_p_block_heads
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
    """Read prefill/decode TP and DP sizes from the kv_transfer extra config.

    Stores them as ``_prefill_tp_size``, ``_prefill_dp_size``,
    ``_decode_tp_size`` and ``_decode_dp_size``; asserts that each key is
    present in its section.
    """
    # get prefill tp and dp size from extra config
    prefill_cfg: dict[
        str, Any] = vllm_config.kv_transfer_config.get_from_extra_config(
            "prefill", {})
    assert "tp_size" in prefill_cfg.keys()
    self._prefill_tp_size = prefill_cfg["tp_size"]
    assert "dp_size" in prefill_cfg.keys()
    self._prefill_dp_size = prefill_cfg["dp_size"]
    # get decode tp and dp size from extra config
    decode_cfg: dict[
        str, Any] = vllm_config.kv_transfer_config.get_from_extra_config(
            "decode", {})
    assert "tp_size" in decode_cfg.keys()
    self._decode_tp_size = decode_cfg["tp_size"]
    assert "dp_size" in decode_cfg.keys()
    self._decode_dp_size = decode_cfg["dp_size"]
def _initialize(
self,
hostname: str,
device_name: Optional[str],
) -> None:
"""Initialize the mooncake instance."""
device_name = device_name if device_name is not None else ""
ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
device_name)
if ret_value != 0:
raise RuntimeError(
f"Mooncake initialization failed with ret_value: {ret_value}")
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the KV Cache data.

    Inspects the first entry of ``kv_caches`` to decide the cache layout
    (MLA, sparse, or plain), computes the byte length of one block for
    every cache component, registers each cache's memory region with the
    transfer engine, and then starts the sending thread (when ``kv_role``
    is ``'kv_producer'``) or the receiving thread (otherwise). The call
    blocks until the started thread sets its ready event.

    Args:
        kv_caches: mapping from a layer key to that layer's caches. Each
            value is indexed and iterated as a sequence of tensors
            (2 components for MLA / plain, 3 for sparse).

    Raises:
        RuntimeError: if a memory registration is rejected by the engine.
    """
    _, first_kv_cache_tuple = next(iter(kv_caches.items()))
    first_kv_cache = first_kv_cache_tuple[0]

    # TODO(tms): Find a more robust way to detect and handle MLA
    self.use_mla = first_kv_cache_tuple[0].size(
        -1) != first_kv_cache_tuple[1].size(-1) and len(
            first_kv_cache_tuple) == 2
    self.use_sparse = len(first_kv_cache_tuple) == 3
    # Dimension 0 is the block index in every supported layout.
    self.num_blocks = first_kv_cache.shape[0]

    if self.use_mla:
        # MLA case.[num_block, block_size, 1, hidden_dim]
        block_rank = 3  # trailing dims that make up one block
        norm_cache = first_kv_cache_tuple[0]
        pe_cache = first_kv_cache_tuple[1]
        block_shape_norm = norm_cache.shape[-block_rank:]
        block_shape_pe = pe_cache.shape[-block_rank:]
        # Take the element size from each component tensor itself.
        # Indexing the first tensor by component number would read a
        # *block* of that tensor and fail when num_blocks is small.
        self.block_len = [
            norm_cache.element_size() * math.prod(block_shape_norm),
            pe_cache.element_size() * math.prod(block_shape_pe),
        ]
        logger.info(
            "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
            self.num_blocks, block_shape_norm, block_shape_pe)
    elif self.use_sparse:
        block_rank = 3  # trailing dims that make up one block
        norm_cache = first_kv_cache_tuple[0]
        pe_cache = first_kv_cache_tuple[1]
        k_cache = first_kv_cache_tuple[2]
        block_shape_norm = norm_cache.shape[-block_rank:]
        block_shape_pe = pe_cache.shape[-block_rank:]
        block_shape_k = k_cache.shape[-block_rank:]
        self.block_len = [
            norm_cache.element_size() * math.prod(block_shape_norm),
            pe_cache.element_size() * math.prod(block_shape_pe),
            k_cache.element_size() * math.prod(block_shape_k),
        ]
        logger.info(
            "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
            self.num_blocks, block_shape_norm, block_shape_pe,
            block_shape_k)
    else:
        # eager:[num_block, block_size, num_head, hidden_dim]
        # torchair:[num_block, block_size, num_head*hidden_dim]
        kv_elem_size = first_kv_cache.element_size()
        # Everything after the leading block dimension belongs to a block:
        # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
        block_rank = len(first_kv_cache.shape) - 1
        block_shape = first_kv_cache.shape[-block_rank:]
        self.block_len = [kv_elem_size * math.prod(block_shape)]
        logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                    block_shape)

    logger.info(
        "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
        self.use_mla, self.use_sparse, first_kv_cache.shape)

    self.kv_caches = kv_caches
    kv_caches_base_addr = []
    # block_len holds one entry per cache component (2 for MLA, 3 for
    # sparse, 1 for the plain layout), so cycling through it by position
    # picks the matching per-block length in every layout.
    num_components = len(self.block_len)
    for cache_or_caches in kv_caches.values():
        for i, cache in enumerate(cache_or_caches):
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len[i %
                                                          num_components]
            kv_caches_base_addr.append(base_addr)
            self._register(base_addr, region_len)

    # After KV Caches registered, start the sending or receiving thread.
    metadata = MooncakeAgentMetadata(
        engine_id=self.engine_id,
        te_rpc_port=self.te_rpc_port,
        kv_caches_base_addr=kv_caches_base_addr,
        num_blocks=self.num_blocks,
    )

    ready_event = threading.Event()
    if self.kv_role == 'kv_producer':
        self.kv_send_thread = KVCacheSendingThread(
            self.tp_rank, self._decode_tp_size, self.engine_id,
            self.side_channel_host, self.side_channel_port, metadata,
            ready_event, self.kv_caches)
        self.kv_send_thread.start()
    else:
        self.kv_recv_thread = KVCacheRecvingThread(
            self.tp_rank, self.tp_size, self.engine, self.engine_id,
            self.handshake_port, kv_caches_base_addr, self.block_len,
            ready_event, self.vllm_config, self.kv_caches)
        self.kv_recv_thread.start()
    # Block until the background thread reports it is ready.
    ready_event.wait()
def _register(self, ptr, length):
    """Register one memory region with the transfer engine.

    Args:
        ptr: start address of the region (logged as hexadecimal).
        length: region size handed to ``self.engine.register_memory``.

    Raises:
        RuntimeError: if the engine returns a non-zero status.
    """
    logger.debug(
        "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
        "block_lens=%s", ptr, length, self.num_blocks, self.block_len)
    status = self.engine.register_memory(ptr, length)
    if status == 0:
        return
    raise RuntimeError("Mooncake memory registration failed.")
def get_finished(self) -> tuple[set[str], set[str]]:
    """Collect the requests whose KV transfer has finished.

    Returns:
        ``(done_sending, done_recving)``. The first set is drained from
        the sending thread only when ``kv_role`` is ``'kv_producer'``; the
        second is drained from the receiving thread only when ``kv_role``
        is ``'kv_consumer'``. The side that does not apply is an empty set.
    """
    if self.kv_role == 'kv_producer':
        done_sending = self.kv_send_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
        )
    else:
        done_sending = set()

    if self.kv_role == 'kv_consumer':
        done_recving = self.kv_recv_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
        )
    else:
        done_recving = set()

    # Only TP rank 0 reports the counts.
    if self.tp_rank == 0:
        logger.debug(
            "Number of completed KV cache send requests: %d, receive "
            "requests: %d", len(done_sending), len(done_recving))
    return done_sending, done_recving
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
    """Start loading KV blocks from remote engine.

    For every entry in ``metadata.requests`` this queues
    ``self.num_need_pulls`` pull requests on the receiving thread, one per
    selected remote TP rank. If a sending thread exists, every entry in
    ``metadata.requests_to_send`` is then handed to it, either as a
    delayed request (when this TP rank is among the ranks chosen for that
    request) or as a request this rank does not transfer.
    """
    for req_id, meta in metadata.requests.items():
        logger.debug(
            "start_load_kv for request %s from remote engine %s. "
            "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
            meta.remote_engine_id, len(meta.local_block_ids),
            len(meta.remote_block_ids))
        remote_ranks = self._get_remote_tp_rank(req_id)
        # Each remote rank's handshake port is the base remote port
        # shifted by the rank number.
        handshake_ports = [rank + meta.remote_port for rank in remote_ranks]
        for pull_idx in range(self.num_need_pulls):
            assert self.kv_recv_thread is not None
            self.kv_recv_thread.add_request(
                request_id=req_id,
                local_block_ids=meta.local_block_ids,
                remote_block_ids=meta.remote_block_ids,
                remote_engine_id=meta.remote_engine_id,
                remote_host=meta.remote_host,
                remote_handshake_port=handshake_ports[pull_idx],
                offset=pull_idx,
                num_need_pulls=self.num_need_pulls)

    if self.kv_send_thread is None:
        return
    for req_id, delay_start_time in metadata.requests_to_send.items():
        selected_ranks = self._prefill_get_remote_tp_rank(req_id)
        if self.tp_rank in selected_ranks:
            self.kv_send_thread.add_delayed_request(req_id, delay_start_time)
        else:
            self.kv_send_thread.add_not_transfer_request(req_id)
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
    """Return every TP rank selected for ``req_id`` as one flat list.

    Concatenates, in order, the per-rank groups produced by
    ``_get_remote_tp_ranks_for_req``.
    """
    rank_groups = self._get_remote_tp_ranks_for_req(req_id)
    return [rank for group in rank_groups for rank in group]
def _get_remote_tp_rank(self, req_id: str) -> List[int]:
    """Return the remote TP ranks assigned to this worker for ``req_id``.

    Picks the entry at index ``self.tp_rank`` from the groups produced by
    ``_get_remote_tp_ranks_for_req``.
    """
    rank_groups = self._get_remote_tp_ranks_for_req(req_id)
    return rank_groups[self.tp_rank]
def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
    """Compute the groups of prefill TP ranks used for ``req_id``.

    * Equal prefill/decode TP sizes: one single-element group per rank,
      ``[[0], [1], ...]``.
    * Prefill TP size above ``num_key_value_heads``, or a DeepSeek-MLA /
      sparse model: prefill ranks are arranged in a grid, columns are
      picked with an RNG seeded from a hash of ``req_id`` (so the choice
      is reproducible per request), and the picked ranks are cut into
      chunks of ``num_need_pulls``.
    * Otherwise: prefill ranks are cut into ``_decode_tp_size``
      consecutive slices of equal length.

    Returns:
        A list of rank lists. ``_get_remote_tp_rank`` indexes it by TP rank.
    """
    prefill_tp = self._prefill_tp_size
    decode_tp = self._decode_tp_size
    if prefill_tp == decode_tp:
        return [[rank] for rank in range(prefill_tp)]

    # Seed from the request id so the selection is stable per request.
    rng = random.Random(string_to_int64_hash(req_id))
    prefill_ranks = np.arange(prefill_tp)
    model_config = self.vllm_config.model_config

    needs_random_split = (prefill_tp > self.num_key_value_heads
                          or model_config.is_deepseek_mla or self.use_sparse)
    if not needs_random_split:
        # non-random split: consecutive, equally sized slices
        span = prefill_tp // decode_tp
        return [
            prefill_ranks[idx * span:(idx + 1) * span].tolist()
            for idx in range(decode_tp)
        ]

    # random split prefill tp list
    # use deepseek mla, num_key_value_heads == 128, but consider as 1
    if model_config.is_deepseek_mla or self.use_sparse:
        kv_heads = 1
    else:
        kv_heads = self.num_key_value_heads
    group_count = len(prefill_ranks) // kv_heads
    rank_grid = prefill_ranks.reshape(-1, group_count)
    # random choose a group
    picked_columns = rng.sample(range(group_count),
                                max(decode_tp // kv_heads, 1))
    picked_ranks = rank_grid[:, picked_columns].reshape(-1).tolist()
    chunk = self.num_need_pulls
    return [
        picked_ranks[start:start + chunk]
        for start in range(0, len(picked_ranks), chunk)
    ]
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
            addr: str) -> Iterator[zmq.Socket]:  # type: ignore
    """Context manager for a ZMQ socket

    Creates a fresh ZMQ context, yields a socket built by
    ``make_zmq_socket`` for ``addr`` (``bind`` is true only for ROUTER
    sockets), and destroys the context with ``linger=0`` on exit.

    Raises:
        ValueError: if ``socket_type`` is not ROUTER, REQ or DEALER.
    """
    allowed_types = (zmq.ROUTER, zmq.REQ, zmq.DEALER)  # type: ignore
    if socket_type not in allowed_types:
        raise ValueError(f"Unexpected socket type: {socket_type}")

    context: Optional[zmq.Context] = None  # type: ignore
    try:
        context = zmq.Context()  # type: ignore
        sock = make_zmq_socket(ctx=context,
                               path=addr,
                               socket_type=socket_type,
                               bind=socket_type == zmq.ROUTER)  # type: ignore
        yield sock
    finally:
        # The context may be None if zmq.Context() itself raised.
        if context is not None:
            context.destroy(linger=0)
def group_concurrent_contiguous(
    src: List[int], dst: List[int]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split two parallel index lists into jointly contiguous runs.

    Vectorised NumPy implementation. A new group starts at every position
    where either ``src`` or ``dst`` does not increase by exactly 1 relative
    to the previous element, so inside each group both sequences are
    consecutive.

    Args:
        src: source indices.
        dst: destination indices, paired position by position with ``src``.

    Returns:
        ``(src_groups, dst_groups)``: two lists of plain Python ``int``
        lists, split at the same positions. Both are empty when ``src`` is
        empty.
    """
    src_indices = np.array(src, dtype=np.int64)
    dst_indices = np.array(dst, dtype=np.int64)

    if src_indices.size == 0:
        return [], []

    # Positions (shifted by one) where either sequence stops being
    # consecutive; these are the split points shared by both arrays.
    brk = np.where((np.diff(src_indices) != 1)
                   | (np.diff(dst_indices) != 1))[0] + 1

    src_groups = [group.tolist() for group in np.split(src_indices, brk)]
    dst_groups = [group.tolist() for group in np.split(dst_indices, brk)]
    return src_groups, dst_groups
def string_to_int64_hash(input_str):
    """
    Hash the string using SHA-256 and convert it into a 64-bit integer.

    The UTF-8 encoding of ``input_str`` is hashed, and the first 8 bytes of
    the digest are read as a little-endian unsigned integer, so the result
    lies in ``[0, 2**64)``.
    """
    digest = hashlib.sha256(input_str.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little")
def ensure_zmq_send(
        socket: zmq.Socket,  # type: ignore
        data: bytes,
        max_retries: int = 3):
    """Send ``data`` on ``socket``, retrying on ``zmq.ZMQError``.

    Each failed attempt uses up one retry. While retries remain, a warning
    is logged and the call sleeps 0.1 s before trying again.

    Args:
        socket: object whose ``send`` method is called with ``data``.
        data: payload to send.
        max_retries: number of attempts allowed before giving up.

    Raises:
        RuntimeError: once no retries are left.
    """
    attempts_left = max_retries
    while True:
        try:
            socket.send(data)
        except zmq.ZMQError as err:  # type: ignore
            attempts_left -= 1
            if attempts_left <= 0:
                logger.error(f"Send failed after all retries: {err}")
                raise RuntimeError(f"Failed to send data after {max_retries} "
                                   f"retries: {err}")
            logger.warning(
                f"Send failed: {err}, retrying... ({attempts_left} "
                "attempts left)")
            time.sleep(0.1)
        else:
            return
def ensure_zmq_recv(
        socket: zmq.Socket,  # type: ignore
        poller: zmq.Poller,  # type: ignore
        timeout: float = 1.0,
        max_retries: int = 3) -> bytes:
    """Receive one message from ``socket``, retrying on timeout or error.

    Polls ``poller`` for up to ``timeout`` seconds. An empty poll result is
    treated as a ``zmq.ZMQError`` ("Receive timeout"). Each failure uses up
    one retry; while retries remain, a warning is logged and the call
    sleeps 0.1 s before polling again.

    Args:
        socket: object whose ``recv`` method supplies the message.
        poller: object whose ``poll`` method is called with the timeout in
            milliseconds.
        timeout: poll timeout in seconds.
        max_retries: number of attempts allowed before giving up.

    Returns:
        The value returned by ``socket.recv()``.

    Raises:
        RuntimeError: once no retries are left.
    """
    attempts_left = max_retries
    while True:
        try:
            ready = dict(poller.poll(int(timeout * 1000)))  # milliseconds
            if not ready:
                raise zmq.ZMQError("Receive timeout")  # type: ignore
            return socket.recv()
        except zmq.ZMQError as err:  # type: ignore
            attempts_left -= 1
            if attempts_left <= 0:
                logger.error(f"Receive failed after all retries: {err}")
                raise RuntimeError(
                    f"Failed to receive data after {max_retries} "
                    f"retries: {err}")
            logger.warning(f"Receive failed: {err}, retrying... "
                           f"({attempts_left} attempts left)")
            time.sleep(0.1)