This commit is contained in:
2026-02-10 23:08:39 +08:00
parent 1baa36026c
commit 6680585975
172 changed files with 52867 additions and 892 deletions

View File

@@ -0,0 +1,449 @@
import array
import hashlib
import json
import os
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import cdiv, logger
from vllm.v1.core.sched.output import NewRequestData
@dataclass
class MooncakeEngineMetadata:
"""name of the LLM model"""
model_name: str
""" world size when running under a distributed setting """
world_size: int
""" worker id when running under a distributed setting """
worker_id: int
""" the format of kv tensors """
kv_dtype: torch.dtype
""" the shape of kv tensors """
""" (num_layer, 2, metadata.block_size, num_kv_head, head_size) """
kv_shape: tuple[int, int, int, int, int]
block_size: int = 128
""" whether use MLA"""
use_mla: bool = False
@dataclass(order=True)
class MooncakeEngineKey:
model_name: str
world_size: int
worker_id: int
chunk_hash: str
def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
))
def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}")
def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
"""Split the key into multiple keys for each layer"""
keys = []
for layer_id in range(num_layers):
keys.append(
LayerMooncakeEngineKey(
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
layer_id,
))
return keys
def to_dict(self):
# Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
return {
"__type__": "CacheEngineKey",
"model_name": self.model_name,
"world_size": self.world_size,
"worker_id": self.worker_id,
"chunk_hash": self.chunk_hash,
}
@staticmethod
def from_dict(d):
return MooncakeEngineKey(
model_name=d["model_name"],
world_size=d["world_size"],
worker_id=d["worker_id"],
chunk_hash=d["chunk_hash"],
)
@dataclass(order=True)
class LayerMooncakeEngineKey(MooncakeEngineKey):
"""A key for the layer cache engine"""
layer_id: int
def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
self.layer_id,
))
def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
class ChunkedTokenDatabase():
def __init__(
self,
metadata: MooncakeEngineMetadata,
):
self.metadata = metadata
def _make_key_by_hash(self,
chunk_hash: str,
layer_id: Optional[int] = None):
assert self.metadata is not None
return MooncakeEngineKey(
self.metadata.model_name,
self.metadata.world_size,
self.metadata.worker_id,
chunk_hash,
)
def _hash(
self,
tokens: Union[torch.Tensor, List[int]],
prefix_hash: str,
) -> str:
# TODO: change it to a more efficient hash function
if isinstance(tokens, torch.Tensor):
tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
elif isinstance(tokens, list):
tokens_bytes = array.array("I", tokens).tobytes()
return hashlib.sha256(prefix_hash.encode("ascii") +
tokens_bytes).hexdigest()
def _chunk_tokens(
self,
tokens: Union[torch.Tensor, List[int]],
) -> Iterable[Union[torch.Tensor, List[int]]]:
"""
Chunk the tokens into chunks of size self.metadata.block_size.
:param tokens: the input tokens, with shape [seq_len]
device: the target device after chunking
:return: a generator of chunks of tokens, each with
shape [metadata.block_size]
"""
for i in range(0, len(tokens), self.metadata.block_size):
yield tokens[i:i + self.metadata.block_size]
def _prefix_hash(
self,
token_chunks: Iterable[Union[torch.Tensor, List[int]]],
) -> Iterable[str]:
prefix_hash = ''
for token_chunk in token_chunks:
prefix_hash = self._hash(token_chunk, prefix_hash)
yield prefix_hash
def process_tokens(
self,
tokens: Union[torch.Tensor, List[int]],
mask: Optional[torch.Tensor] = None,
) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
"""Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched,
and the Falses will ALWAYS be at the PREFIX of the tensor.
:param bool make_key: Whether to make the cache engine key or not.
If False, the hash value will be returned instead.
:returns: A iterable of tuples with three elements. The first element
is the start index of the tokens for the key. The second element
is the end index of the tokens for the key. The third element is
the cache engine key (or hash) for the tokens.
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
"""
if mask is not None:
num_falses = mask.numel() - mask.long().sum().item()
else:
num_falses = 0
if num_falses % self.metadata.block_size != 0:
raise ValueError(
"The number of Falses in the mask is not a multiple of the chunk size."
)
total_len = len(tokens)
token_chunks = self._chunk_tokens(tokens)
prefix_hashes = self._prefix_hash(token_chunks)
start_idx = 0
for chunk_id, hash_val in enumerate(prefix_hashes):
start_idx = chunk_id * self.metadata.block_size
end_idx = min(start_idx + self.metadata.block_size, total_len)
if start_idx < num_falses:
continue
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
@dataclass
class LoadSpec:
# Number of tokens cached in vLLM
vllm_cached_tokens: int
# Number of tokens that are cached in mooncake
mooncake_cached_tokens: int
# Whether the scheduler allow us to load the tokens
can_load: bool
@dataclass
class SaveSpec:
# Skip already saved tokens
skip_leading_tokens: int
# Whether the scheduler allow us to save the tokens
can_save: bool
@dataclass
class RequestTracker:
# Request id
req_id: str
# The token ids that has been scheduled so far
token_ids: list[int]
# The block ids that has been allocated so far
# NOTE: allocated blocks could be more than the number of tokens
# FIXME: need to check whether the block ids will be changed after
# preemption
allocated_block_ids: list[int]
# The number of tokens that has been savd
num_saved_tokens: int = 0
@staticmethod
def from_new_request(
new_request: "NewRequestData",
num_tokens_to_compute: int,
) -> "RequestTracker":
"""Create the request tracker from a new request.
Args:
new_request (NewRequestData): the new request data.
num_tokens_to_compute (int): the number of tokens that will
be 'computed', including the `num_computed_tokens` (vLLM's
local cache hit) and new tokens that will be scheduled.
"""
# vLLM 0.9.0 update: request.block_ids changed from list[int] to
# list[list[int]]
# Need to check the type of request.block_ids
unfolded_block_ids = []
if not isinstance(new_request.block_ids[0], list):
unfolded_block_ids = new_request.block_ids.copy()
else:
unfolded_block_ids = new_request.block_ids[0].copy()
return RequestTracker(
req_id=new_request.req_id,
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
copy(),
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_ids.extend(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
new_block_ids = new_block_ids[0]
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(
f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
@dataclass
class ReqMeta:
# Request id
req_id: str
# Request tokens
token_ids: torch.Tensor
block_ids: list[int]
# # Slot mapping if exchange for block_id
# slot_mapping: torch.Tensor
# Skip save or not
save_spec: Optional[SaveSpec] = None
# load_spec
load_spec: Optional[LoadSpec] = None
is_last_chunk: Optional[bool] = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
block_size: int,
load_spec: Optional[LoadSpec] = None,
skip_save: Optional[bool] = False,
is_last_chunk: Optional[bool] = None,
discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks.
Returns:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
input_token_ids = tracker.token_ids
input_token_len = len(input_token_ids)
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
skip_leading_tokens = tracker.num_saved_tokens
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
block_size if discard_partial_chunks else 0)
# Calculate number of tokens to save based on discard_partial_chunks
# setting
num_tokens_to_save = ((input_token_len // block_size * block_size)
if discard_partial_chunks else input_token_len)
skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None:
return None
# If we need to save, update the number of saved tokens
if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save
save_spec = SaveSpec(skip_leading_tokens, not skip_save)
# Calculate the token ids and slot mappings for load and save
# OPTIMIZATION: pre-allocate the buffer for token ids and block ids
token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]
# # For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load:
logger.debug(
"Scheduled to load %d tokens for request %s",
load_spec.mooncake_cached_tokens,
tracker.req_id,
)
else:
# Do not load if not in `can_load` state
load_spec = None
logger.debug(
f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}"
)
return ReqMeta(
req_id=tracker.req_id,
token_ids=token_ids,
block_ids=tracker.allocated_block_ids,
save_spec=save_spec,
load_spec=load_spec,
is_last_chunk=is_last_chunk,
)
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
def add_request(self, req_meta: ReqMeta) -> None:
"""Add a request to the metadata.
Args:
req_meta (ReqMeta): the request metadata.
"""
self.requests.append(req_meta)
@dataclass
class LasyerMultiBlockReqMeta:
req_id: str
keys: List[LayerMooncakeEngineKey]
starts: List[int]
ends: list[int]
block_ids: list[int]
layer_id: int
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: int
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
use_ascend_direct: bool
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
with open(file_path) as file:
config = json.load(file)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=config.get("global_segment_size", 3355443200),
local_buffer_size=config.get("local_buffer_size", 1073741824),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
use_ascend_direct=config.get("use_ascend_direct", False))
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path)

View File

@@ -0,0 +1,293 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
import torch
from vllm.utils import logger
from vllm_npu.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
class KVTransferThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.tp_rank = tp_rank
self.tp_size = tp_size
self.m_store = m_store
self.ready_event = ready_event
self.kv_caches_base_addr = local_kv_caches_base_addr
self.block_len = block_len
self.token_database = token_database
self.block_size = block_size
self.done_task_lock = threading.Lock()
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set()
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
layer_id: int):
block_id = block_ids[start // self.block_size]
if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v]
else:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
return addr_list, size_list
def add_request(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
is_last_chunk: Optional[bool] = None,
current_event: Optional[torch.npu.Event] = None,
) -> torch.Tensor:
req = ({
"req_id": req_id,
"tokens": tokens,
"block_ids": block_ids,
"mask": mask,
"is_last_chunk": is_last_chunk,
"current_event": current_event,
})
self.request_queue.put(req)
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
self.finished_requests.clear()
return finished_requests
def set_finished_request(self, req_id):
with self.done_task_lock:
self.finished_requests.add(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
self.ready_event.set()
while True:
try:
request_data = self.request_queue.get()
if request_data is None:
logger.warning("Received a None request!")
self.request_queue.task_done()
continue
self._handle_request(request_data)
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: dict[str, Any]):
pass
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheSendingThread")
def _handle_request(self, req_meta: dict[str, Any]):
tokens = req_meta["tokens"]
mask = req_meta["mask"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
current_event = req_meta["current_event"]
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
if key_list:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
if current_event is not None:
current_event.synchronize()
self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
if current_event is not None:
current_event.synchronize()
self.m_store.put(key, addr, size)
if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: dict[str, Any]):
tokens = req_meta["tokens"]
mask = req_meta["mask"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.get(key, addr, size)
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event,
num_layers: int):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreLayerSendingThread")
self.final_layer_id = num_layers - 1
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
for index, key in enumerate(req_meta.keys):
addr, size = self.prepare_value_layer(req_meta.starts[index],
req_meta.ends[index],
req_meta.block_ids,
req_meta.layer_id)
self.m_store.put(key, addr, size)
if req_meta.layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event,
get_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreLayerRecvingThread")
self.get_event = get_event
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
for index, key in enumerate(req_meta.keys):
addr, size = self.prepare_value_layer(req_meta.starts[index],
req_meta.ends[index],
req_meta.block_ids,
req_meta.layer_id)
self.m_store.get(key, addr, size)
self.request_queue.task_done()
self.get_event.set()

View File

@@ -0,0 +1,639 @@
# Standard
import math
import threading
import time
from typing import Generator, List, Optional, Union
# Third Party
import torch
from vllm.config import VllmConfig
from vllm.utils import get_kv_cache_torch_dtype, logger
from vllm_npu.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
MooncakeEngineMetadata)
from vllm_npu.distributed.mooncake.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
class MooncakeEngine:
#The main class for the cache engine.
def __init__(
self,
vllm_config: VllmConfig,
use_layerwize: bool,
):
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.use_mla = False
if (hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla):
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = parallel_config.rank
self.tp_size = parallel_config.tensor_parallel_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"register_buffer", False)
self.block_size = vllm_config.cache_config.block_size
self.current_layer = 0
# self.use_mla = first_kv_cache_tuple[0].size(
# -1) != first_kv_cache_tuple[1].size(-1)
self.num_layers = model_config.get_num_layers(parallel_config)
self.block_size = vllm_config.cache_config.block_size
num_kv_head = model_config.get_num_kv_heads(parallel_config)
head_size = model_config.get_head_size()
kv_dtype = get_kv_cache_torch_dtype(
vllm_config.cache_config.cache_dtype, model_config.dtype)
self.hidden_dim_size = num_kv_head * head_size
if self.use_mla:
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
else:
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
head_size)
self.metadata = MooncakeEngineMetadata(
model_config.model,
parallel_config.world_size,
parallel_config.rank,
kv_dtype,
kv_shape,
self.block_size,
self.use_mla,
)
self.token_database = ChunkedTokenDatabase(self.metadata)
self.m_store = Mooncakestore(parallel_config)
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
# TODO(tms): Find a more robust way to detect and handle MLA
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
self.kv_caches = kv_caches
self.kv_caches_base_addr = []
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
if self.register_buffer:
region_len = self.num_blocks * self.block_len[i % 2]
self._register(base_addr, region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
if self.register_buffer:
region_len = self.num_blocks * self.block_len[0]
self._register(base_addr, region_len)
if self.use_layerwise:
self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event_sending,
self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database, self.block_len,
self.block_size, ready_event, self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()
def _register(self, ptr, length):
logger.debug(
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
try:
self.m_store.register_buffer(ptr, length)
except Exception as e:
raise RuntimeError(
f"Mooncake memory registration failed. Error is: {e}")
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
self.current_layer = 0
self.layerwise_retrievers = []
for request in metadata.requests:
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load: #load =0
continue
tokens = request.token_ids
req_id = request.req_id
if (load_spec.mooncake_cached_tokens % self.block_size
!= 0) and (load_spec.mooncake_cached_tokens
== tokens.shape[0] - 1):
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
else:
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
masked_token_count = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
token_mask = torch.ones_like(tokens, dtype=torch.bool)
token_mask[:masked_token_count] = False
if self.use_layerwise:
layerwise_retriever = self.retrieve_layer(
req_id,
tokens,
request.block_ids,
token_mask,
)
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
tokens,
request.block_ids,
token_mask,
)
else:
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, block_id = self.prepare_value(
start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list,
blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, _ = self.prepare_value(
start, end, request.block_ids)
self.m_store.get(key, addr, size)
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def wait_for_layer_load(self) -> None:
"""MooncakeConnector does not do layerwise saving."""
for layerwise_retriever in self.layerwise_retrievers:
ret_token_mask = next(layerwise_retriever)
if self.current_layer == self.num_layers - 1:
assert ret_token_mask is not None
num_retrieved_tokens = ret_token_mask.sum().item()
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self,
connector_metadata: MooncakeConnectorMetadata) -> None:
"""MooncakeConnector does not save explicitly."""
if self.current_layer == 0:
self.layerwise_storers = []
current_event = None
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
token_ids = request.token_ids
req_id = request.req_id
assert isinstance(token_ids, torch.Tensor)
assert token_ids.is_cpu
# TODO: whether need to remov saveThread
# no lookup, skipmask
skip_leading_tokens = max(
self.lookup(token_ids, self.use_layerwise),
save_spec.skip_leading_tokens,
)
if skip_leading_tokens == len(token_ids):
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
skip_leading_tokens = (skip_leading_tokens // self.block_size *
self.block_size)
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
store_mask[:skip_leading_tokens] = False
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
len(token_ids) - skip_leading_tokens,
len(token_ids),
skip_leading_tokens,
request.req_id,
)
layerwise_storer = self.store_layer(
req_id,
token_ids,
mask=store_mask,
block_ids=request.block_ids,
)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
next(layerwise_storer)
except Exception:
raise
self.current_layer = self.current_layer + 1
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
"""MooncakeConnector does not save explicitly."""
current_event = None
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
token_ids = request.token_ids
req_id = request.req_id
assert isinstance(token_ids, torch.Tensor)
assert token_ids.is_cpu
skip_leading_tokens = max(
self.lookup(token_ids, self.use_layerwise),
save_spec.skip_leading_tokens,
)
if skip_leading_tokens == len(token_ids):
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
skip_leading_tokens = (skip_leading_tokens // self.block_size *
self.block_size)
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
store_mask[:skip_leading_tokens] = False
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
len(token_ids) - skip_leading_tokens,
len(token_ids),
skip_leading_tokens,
request.req_id,
)
self.kv_send_thread.add_request( # type: ignore[union-attr]
req_id,
token_ids,
request.block_ids,
store_mask,
request.is_last_chunk,
current_event,
)
def retrieve_layer(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched.
:param **kwargs: The additional arguments for the KV transfer which
will be passed into the npu_transfer.
return: A generator that yields Optional[torch.Tensor]. The tensor will
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
"""
if mask is not None:
num_required_tokens = torch.sum(mask).item()
else:
num_required_tokens = len(tokens)
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
starts = []
ends = []
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
tokens, mask):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer)
ret_mask[start:end] = True
if keys:
# Transpose the keys into layer major format
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag:
is_finish = self.get_event.wait(timeout=3) #try---cache
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False
yield None
else:
# If no cache are found, we still need to yield to avoid
# `StopIteration`
for layer_id in range(self.num_layers):
yield None
retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {len(tokens)} tokens")
yield ret_mask
def store_layer(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched.
:param **kwargs: The additional arguments for the storage backend which
will be passed into the gpu_connector.
return: A generator that yields None. In the first iteration, the
generator allocates the memory objects for all layers and moves
the KV cache of the first layer from GPU to CPU. In the next
iterations, it moves the KV cache of layer i from GPU to the memory
objects (on CPU) and puts the memory objects of layer i-1 to the
storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends.
"""
if mask is not None:
num_stored_tokens = torch.sum(mask).item()
else:
num_stored_tokens = len(tokens)
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer) #[block_num,layer_num]
if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield
logger.debug(
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
done_recving = (
self.kv_recv_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.load_async else set())
logger.debug(
"Number of completed KV cache send requests: %d, receive "
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
self.tp_rank)
return done_sending, done_recving
def wait_layer_transfer_finish(self):
time.sleep(10)
pass
def lookup(
self,
tokens: Union[torch.Tensor, List[int]],
use_layerwise: bool,
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
end = 0
keys = []
try:
if use_layerwise:
for start, end, key in self.token_database.process_tokens(
tokens):
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
# batch is_exists
ress = self.m_store.batch_exists(keys)
res = 1
for value in ress:
if value != 1:
res = 0
break
if res == 1:
continue
else:
return start
else:
starts = []
for start, end, key in self.token_database.process_tokens(
tokens):
keys.append(key.to_string())
starts.append(start)
res = self.m_store.batch_exists(
keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return starts[index]
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def lookup_scheduler(
self,
tokens: Union[torch.Tensor, List[int]],
use_layerwise: bool,
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
end = 0
keys = []
try:
if use_layerwise:
for start, end, key in self.token_database.process_tokens(
tokens):
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
# batch is_exists
ress = self.m_store.batch_exists(keys)
res = 1
for value in ress:
if value != 1:
res = 0
break
if res == 1:
continue
else:
return start
else:
starts = []
for start, end, key in self.token_database.process_tokens(
tokens):
keys.append(key.to_string())
starts.append(start)
multi_tp_keys = keys[:]
for i in range(1, self.tp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@0", f"@{i}", 1)
multi_tp_keys.append(new_str)
res = self.m_store.batch_exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
multi_tp_values = [
res[i * num_block:(i + 1) *
num_block] # type: ignore[index]
for i in range(self.tp_size)
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
return starts[index]
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def find_min_first_non_one_index(self, arr):
try:
return min(idx for row in arr for idx, val in enumerate(row)
if val != 1)
except ValueError:
return -1
def close(self) -> None:
"""Close the cache engine and free all the resources"""
self.m_store.close()

View File

@@ -0,0 +1,126 @@
# Standard
import os
# Third Party
from mooncake.store import ReplicateConfig # type: ignore
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.utils import get_ip, logger
from vllm_npu.distributed.mooncake.config_data import MooncakeEngineKey
from vllm_npu.distributed.mooncake.transfer_engine import get_global_te
from .config_data import MooncakeStoreConfig
METADATA_BYTES_LEN = 24
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
class Mooncakestore():
def __init__(self, parallel_config: ParallelConfig):
try:
from mooncake.store import MooncakeDistributedStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
tp_rank = get_tensor_model_parallel_rank()
tp_size = parallel_config.tensor_parallel_size
dp_rank = parallel_config.data_parallel_rank_local
all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
if not all_device_ids:
device_ids_list = list(
range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
else:
device_ids_list = list(map(int, all_device_ids.split(',')))
assert len(device_ids_list) > tp_rank
device_id = device_ids_list[tp_rank]
self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore()
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
":npu_" + str(device_id)
ret = self.store.setup(local_hostname, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address)
else:
local_hostname = get_ip()
transfer_engine = get_global_te(local_hostname, device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
raise RuntimeError(msg)
def exists(self, key: MooncakeEngineKey) -> bool:
return self.store.is_exist(key.to_string()) == 1
def batch_exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def register_buffer(self, ptr, length):
return self.store.register_buffer(ptr, length)
def get_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
res = self.store.batch_get_into_multi_buffers(
keys, addrs, sizes, True)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {keys}. {e}")
def put_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(
keys, addrs, sizes, config)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}")
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
expect_res = sum(size)
key_str = key.to_string()
try:
res = self.store.batch_get_into_ascend(key_str, addr, size)
if res[0] != expect_res:
logger.error(f"Failed to get key: [{key_str}] .")
except Exception:
logger.error(f"Failed to get key: [{key_str}] .")
return res
def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
key_str = key.to_string()
try:
ret = self.store.batch_put_from_ascend(key_str, addr, size)
if ret[0] != 0:
logger.error(f"Failed to put key {key_str}.")
except Exception:
logger.error(f"Failed to put key {key_str}.")
return ret
def close(self):
self.store.close()
logger.info("Closed the mooncake store connection")

View File

@@ -0,0 +1,494 @@
import threading
from typing import Any, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext
from vllm.utils import logger, make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm_npu.distributed.mooncake.config_data import (
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
from vllm_npu.distributed.mooncake.mooncake_engine import MooncakeEngine
class MooncakeConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
self.kv_caches: dict[str, torch.Tensor] = {}
self._block_size = vllm_config.cache_config.block_size
self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
vllm_config, self.use_layerwise)
else:
self.connector_worker = MooncakeEngine(
vllm_config,
self.use_layerwise,
)
assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0 and self.kv_role != "kv_consumer":
self.lookup_server = MooncakeLookupServer(
self.connector_worker, vllm_config, self.use_layerwise)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._get_connector_metadata(),
MooncakeConnectorMetadata)
self.connector_worker.start_load_kv(self._get_connector_metadata())
def wait_for_layer_load(self, layer_name: str) -> None:
"""MooncakeStoreConnector does not do layerwise saving."""
if not self.use_layerwise:
return
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""MooncakeStoreConnector does not save explicitly."""
if not self.use_layerwise:
return
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
"""MooncakeStoreConnector does not save explicitly."""
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
if self.use_layerwise:
self.connector_worker.wait_layer_transfer_finish()
return
self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished()
sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids:
sended_and_finished.add(item)
self.sended_but_unfinished_reqs.remove(item)
for item in done_sending:
if item in meta.unfinished_request_ids:
self.sended_but_unfinished_reqs.add(item)
else:
sended_and_finished.add(item)
return sended_and_finished, done_recving
def get_zmq_rpc_path_mooncake(
vllm_config: Optional["VllmConfig"] = None, ) -> str:
base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured
rpc_port = 0
if vllm_config is not None:
rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
"mooncake_rpc_port", 0)
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
class MooncakeStoreConnectorV1Scheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.use_layerwise = use_layerwise
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.client = MooncakeLookupClient(
vllm_config) if self.kv_role != "kv_consumer" else None
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
# request_id -> (vllm cached tokes, mooncake cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self._block_size = vllm_config.cache_config.block_size
# request_id -> full_token_ids
self._request_trackers: dict[str, RequestTracker] = {}
# Whether to discard partial chunks
self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True))
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
self._unfinished_request_ids: set[str] = set()
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Check for external KV cache hit.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
return 0, False
if self._discard_partial_chunks:
token_block_end = len(request.prompt_token_ids
) // self._block_size * self._block_size
token_ids = torch.tensor(
request.prompt_token_ids[:token_block_end])
else:
token_ids = torch.tensor(request.prompt_token_ids)
num_external_hit_tokens = self.client.lookup( # type: ignore[union-attr]
token_ids)
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
need_to_allocate = num_external_hit_tokens - num_computed_tokens
logger.info(
"Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
request.request_id,
request.num_tokens,
num_external_hit_tokens,
need_to_allocate,
)
if need_to_allocate <= 0:
return 0, False
self.load_specs[request.request_id] = LoadSpec(
vllm_cached_tokens=num_computed_tokens,
mooncake_cached_tokens=num_external_hit_tokens,
can_load=False,
)
return need_to_allocate, self.load_async
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after temporary buffer alloc.
For SharedStorageConnector, update _request_needs_load
if the CacheManager this allocated blocks for us.
"""
local_block_ids = []
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0]
self._unfinished_requests[request.request_id] = (request,
local_block_ids)
self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs:
# No KV tokens from external KV cache, return
return
if num_external_tokens == 0:
# No need to load anything
self.load_specs[request.request_id].can_load = False
return
assert (
num_external_tokens > 0 and num_external_tokens
== self.load_specs[request.request_id].mooncake_cached_tokens -
self.load_specs[request.request_id].vllm_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}")
self.load_specs[request.request_id].can_load = True
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output
except the `kv_connector_metadata` field.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = self.kv_role == "kv_consumer"
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)
meta = MooncakeConnectorMetadata(self._unfinished_request_ids)
for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = (
request.num_computed_tokens +
scheduler_output.num_scheduled_tokens[request.req_id])
request_tracker = RequestTracker.from_new_request(
request, num_tokens_to_compute)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
cached_reqs = scheduler_output.scheduled_cached_reqs
if isinstance(cached_reqs, list) and not force_skip_save:
for i, req in enumerate(cached_reqs):
request_tracker = self._request_trackers[req.req_id]
request_tracker.update(req.new_token_ids, req.new_block_ids)
last_chunk_tokens_num = ((len(req.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(req.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
elif not force_skip_save:
for i, req_id in enumerate(cached_reqs.req_ids):
request_tracker = self._request_trackers[req_id]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_tuple = self._unfinished_requests.get(req_id)
if req_tuple:
request = req_tuple[0]
num_current_tokens = len(request_tracker.token_ids)
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
f"but it is scheduled to be cached")
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
request_tracker.update(new_token_ids, new_block_ids)
# decode not save
if len(request_tracker.token_ids) > len(
request.prompt_token_ids):
continue
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
request_ids = [
req.req_id for req in scheduler_output.scheduled_new_reqs
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.mooncake_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
token_ids=request.prompt_token_ids[:num_tokens_to_compute].
copy(),
allocated_block_ids=block_ids,
num_saved_tokens=0,
)
self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=None,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if self.kv_role == "kv_consumer":
return False, None
tracker = self._request_trackers.get(request.request_id)
if tracker is not None and tracker.num_saved_tokens <= 0:
return False, None
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(block_ids), request.request_id)
return delay_free_blocks, None
class MooncakeLookupClient:
def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REQ, # type: ignore[attr-defined]
bind=False,
)
def lookup(self, token_ids: torch.Tensor) -> int:
request = self.encoder.encode(token_ids)
self.socket.send_multipart(request, copy=False)
resp = self.socket.recv()
result = int.from_bytes(resp, "big")
return result
def close(self):
self.socket.close(linger=0)
class MooncakeLookupServer:
def __init__(
self,
mooncake_engine: MooncakeEngine,
vllm_config: "VllmConfig",
use_layerwise: bool,
):
self.decoder = MsgpackDecoder(torch.Tensor)
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REP, # type: ignore[attr-defined]
bind=True,
)
self.mooncake_engine = mooncake_engine
self.running = True
def process_request():
while self.running:
frames = self.socket.recv_multipart(copy=False)
token_ids = self.decoder.decode(frames)
result = self.mooncake_engine.lookup_scheduler(
token_ids, use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)
self.thread = threading.Thread(target=process_request, daemon=True)
self.thread.start()
def close(self):
self.socket.close(linger=0)
# TODO: close the thread!

View File

@@ -0,0 +1,38 @@
import ipaddress
import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore
_global_te = None
_global_te_lock = threading.Lock()
def get_global_te(hostname: str, device_name: Optional[str]):
try:
ip = ipaddress.ip_address(hostname)
if isinstance(ip, ipaddress.IPv6Address):
raise RuntimeError(
"The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
)
except ValueError:
pass
global _global_te
if _global_te is None:
with _global_te_lock:
# Double-Checked Locking
if _global_te is None:
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else ""
ret_value = transfer_engine.initialize(hostname,
"P2PHANDSHAKE",
"ascend", device_name)
if ret_value != 0:
raise RuntimeError(
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
_global_te = transfer_engine
return _global_te