Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git
Synced 2026-02-20 19:50:15 +00:00
Major overhaul
449 vllm_npu/distributed/mooncake/config_data.py Normal file
@@ -0,0 +1,449 @@
import array
import hashlib
import json
import os
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union

import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
    KVConnectorMetadata
from vllm.utils import cdiv, logger
from vllm.v1.core.sched.output import NewRequestData


@dataclass
class MooncakeEngineMetadata:
    """Metadata of the mooncake KV cache engine."""

    model_name: str
    """ name of the LLM model """
    world_size: int
    """ world size when running under a distributed setting """
    worker_id: int
    """ worker id when running under a distributed setting """
    kv_dtype: torch.dtype
    """ the dtype of kv tensors """
    kv_shape: tuple[int, int, int, int, int]
    """ the shape of kv tensors:
    (num_layer, 2, metadata.block_size, num_kv_head, head_size) """
    block_size: int = 128
    use_mla: bool = False
    """ whether to use MLA """


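# A minimal construction sketch with made-up values (the model name, layer
# and head counts below are illustrative assumptions, not defaults of this
# plugin); kv_shape follows the documented layout
# (num_layer, 2, block_size, num_kv_head, head_size).
def _example_metadata() -> MooncakeEngineMetadata:
    return MooncakeEngineMetadata(
        model_name="Qwen2.5-7B-Instruct",
        world_size=2,
        worker_id=0,
        kv_dtype=torch.bfloat16,
        kv_shape=(28, 2, 128, 4, 128),
        block_size=128,
        use_mla=False,
    )

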
@dataclass(order=True)
class MooncakeEngineKey:
    model_name: str
    world_size: int
    worker_id: int
    chunk_hash: str

    def __hash__(self):
        return hash((
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
        ))

    def to_string(self):
        return (f"{self.model_name}@{self.world_size}"
                f"@{self.worker_id}@{self.chunk_hash}")

    def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
        """Split the key into multiple keys for each layer"""
        keys = []
        for layer_id in range(num_layers):
            keys.append(
                LayerMooncakeEngineKey(
                    self.model_name,
                    self.world_size,
                    self.worker_id,
                    self.chunk_hash,
                    layer_id,
                ))
        return keys

    def to_dict(self):
        # Note(Kuntai): this is used for serializing CacheEngineKey via
        # msgpack.
        return {
            "__type__": "CacheEngineKey",
            "model_name": self.model_name,
            "world_size": self.world_size,
            "worker_id": self.worker_id,
            "chunk_hash": self.chunk_hash,
        }

    @staticmethod
    def from_dict(d):
        return MooncakeEngineKey(
            model_name=d["model_name"],
            world_size=d["world_size"],
            worker_id=d["worker_id"],
            chunk_hash=d["chunk_hash"],
        )


@dataclass(order=True)
class LayerMooncakeEngineKey(MooncakeEngineKey):
    """A key for the layer cache engine"""

    layer_id: int

    def __hash__(self):
        return hash((
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
            self.layer_id,
        ))

    def to_string(self):
        return (f"{self.model_name}@{self.world_size}"
                f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")


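# A small round-trip sketch with made-up values: the string form is
# "<model_name>@<world_size>@<worker_id>@<chunk_hash>", to_dict()/from_dict()
# invert each other, and split_layers() fans a key out into per-layer keys.
def _example_key_roundtrip() -> None:
    key = MooncakeEngineKey("demo-model", 2, 0, "abc123")
    assert key.to_string() == "demo-model@2@0@abc123"
    assert MooncakeEngineKey.from_dict(key.to_dict()) == key
    layer_keys = key.split_layers(num_layers=2)
    assert layer_keys[0].to_string() == "demo-model@2@0@abc123@0"

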
class ChunkedTokenDatabase:

    def __init__(
        self,
        metadata: MooncakeEngineMetadata,
    ):
        self.metadata = metadata

    def _make_key_by_hash(self,
                          chunk_hash: str,
                          layer_id: Optional[int] = None):
        assert self.metadata is not None
        return MooncakeEngineKey(
            self.metadata.model_name,
            self.metadata.world_size,
            self.metadata.worker_id,
            chunk_hash,
        )

    def _hash(
        self,
        tokens: Union[torch.Tensor, List[int]],
        prefix_hash: str,
    ) -> str:
        # TODO: change it to a more efficient hash function
        if isinstance(tokens, torch.Tensor):
            tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
        elif isinstance(tokens, list):
            tokens_bytes = array.array("I", tokens).tobytes()
        return hashlib.sha256(prefix_hash.encode("ascii") +
                              tokens_bytes).hexdigest()

    def _chunk_tokens(
        self,
        tokens: Union[torch.Tensor, List[int]],
    ) -> Iterable[Union[torch.Tensor, List[int]]]:
        """
        Chunk the tokens into chunks of size self.metadata.block_size.

        :param tokens: the input tokens, with shape [seq_len]

        :return: a generator of chunks of tokens, each with
            shape [metadata.block_size]
        """
        for i in range(0, len(tokens), self.metadata.block_size):
            yield tokens[i:i + self.metadata.block_size]

    def _prefix_hash(
        self,
        token_chunks: Iterable[Union[torch.Tensor, List[int]]],
    ) -> Iterable[str]:
        prefix_hash = ''
        for token_chunk in token_chunks:
            prefix_hash = self._hash(token_chunk, prefix_hash)
            yield prefix_hash

    def process_tokens(
        self,
        tokens: Union[torch.Tensor, List[int]],
        mask: Optional[torch.Tensor] = None,
    ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
        """Process the tokens and return the corresponding cache engine keys.

        :param Union[torch.Tensor, List[int]] tokens: The tokens to process.

        :param Optional[torch.Tensor] mask: The mask for the tokens. It should
            have the same length as the tokens, and it should ALWAYS look like
            FFFFFTTTTTTT: True means the token needs to be matched, and the
            Falses are ALWAYS at the PREFIX of the tensor.

        :returns: An iterable of tuples with three elements. The first element
            is the start index of the tokens for the key. The second element
            is the end index of the tokens for the key. The third element is
            the cache engine key for the tokens.

        :raises: ValueError if the number of Falses in the mask is not a
            multiple of the chunk size.
        """
        if mask is not None:
            num_falses = mask.numel() - mask.long().sum().item()
        else:
            num_falses = 0

        if num_falses % self.metadata.block_size != 0:
            raise ValueError(
                "The number of Falses in the mask is not a multiple of "
                "the chunk size.")
        total_len = len(tokens)

        token_chunks = self._chunk_tokens(tokens)
        prefix_hashes = self._prefix_hash(token_chunks)

        start_idx = 0
        for chunk_id, hash_val in enumerate(prefix_hashes):
            start_idx = chunk_id * self.metadata.block_size
            end_idx = min(start_idx + self.metadata.block_size, total_len)
            if start_idx < num_falses:
                continue
            else:
                yield start_idx, end_idx, self._make_key_by_hash(hash_val)


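# A minimal usage sketch with made-up values: with block_size=4, a prompt of
# 10 tokens whose first 4 are already covered by the mask yields keys only for
# the spans [4, 8) and [8, 10); each chunk_hash is a rolling prefix hash, so
# it also depends on every token that precedes the chunk.
def _example_process_tokens() -> None:
    metadata = MooncakeEngineMetadata(
        model_name="demo-model",
        world_size=1,
        worker_id=0,
        kv_dtype=torch.float16,
        kv_shape=(2, 2, 4, 1, 64),
        block_size=4,
    )
    db = ChunkedTokenDatabase(metadata)
    tokens = list(range(10))
    # The first 4 tokens are already cached in vLLM -> mask is FFFFTTTTTT.
    mask = torch.tensor([False] * 4 + [True] * 6)
    for start, end, key in db.process_tokens(tokens, mask):
        print(start, end, key.to_string())

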
@dataclass
class LoadSpec:
    # Number of tokens cached in vLLM
    vllm_cached_tokens: int
    # Number of tokens that are cached in mooncake
    mooncake_cached_tokens: int
    # Whether the scheduler allows us to load the tokens
    can_load: bool


@dataclass
class SaveSpec:
    # Skip already saved tokens
    skip_leading_tokens: int
    # Whether the scheduler allows us to save the tokens
    can_save: bool


@dataclass
class RequestTracker:
    # Request id
    req_id: str

    # The token ids that have been scheduled so far
    token_ids: list[int]

    # The block ids that have been allocated so far
    # NOTE: allocated blocks could be more than the number of tokens
    # FIXME: need to check whether the block ids will be changed after
    # preemption
    allocated_block_ids: list[int]

    # The number of tokens that have been saved
    num_saved_tokens: int = 0

    @staticmethod
    def from_new_request(
        new_request: "NewRequestData",
        num_tokens_to_compute: int,
    ) -> "RequestTracker":
        """Create the request tracker from a new request.

        Args:
            new_request (NewRequestData): the new request data.
            num_tokens_to_compute (int): the number of tokens that will
                be 'computed', including the `num_computed_tokens` (vLLM's
                local cache hit) and new tokens that will be scheduled.

        """
        # vLLM 0.9.0 update: request.block_ids changed from list[int] to
        # list[list[int]]
        # Need to check the type of request.block_ids

        if not isinstance(new_request.block_ids[0], list):
            unfolded_block_ids = new_request.block_ids.copy()
        else:
            unfolded_block_ids = new_request.block_ids[0].copy()

        return RequestTracker(
            req_id=new_request.req_id,
            token_ids=new_request.prompt_token_ids[
                :num_tokens_to_compute].copy(),
            allocated_block_ids=unfolded_block_ids,
            num_saved_tokens=0,
        )

    def update(
        self,
        new_token_ids: list[int],
        new_block_ids: Union[tuple[list[int], ...], list[int]],
    ) -> None:
        """Update the request tracker when a running request is
        scheduled again
        """

        self.token_ids.extend(new_token_ids)

        if len(new_block_ids) == 0:
            new_block_ids = []
        elif isinstance(new_block_ids, tuple):
            new_block_ids = new_block_ids[0]
        elif isinstance(new_block_ids, list):
            pass
        else:
            raise ValueError(
                f"Unsupported new_block_ids type {type(new_block_ids)}")
        self.allocated_block_ids.extend(new_block_ids)


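# A lifecycle sketch with made-up ids: once a tracker exists, every later
# scheduler step appends the newly scheduled tokens and newly allocated blocks
# via update(); vLLM may hand the new block ids over either as a flat list or
# wrapped in a tuple, and both forms are accepted.
def _example_request_tracker() -> None:
    tracker = RequestTracker(
        req_id="req-0",
        token_ids=[1, 2, 3, 4],
        allocated_block_ids=[0],
        num_saved_tokens=0,
    )
    tracker.update(new_token_ids=[5, 6, 7, 8], new_block_ids=([1],))
    assert tracker.token_ids == [1, 2, 3, 4, 5, 6, 7, 8]
    assert tracker.allocated_block_ids == [0, 1]

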
@dataclass
class ReqMeta:
    # Request id
    req_id: str
    # Request tokens
    token_ids: torch.Tensor
    # Allocated block ids
    block_ids: list[int]
    # # Slot mapping if exchanged for block_id
    # slot_mapping: torch.Tensor
    # Skip save or not
    save_spec: Optional[SaveSpec] = None
    # Load spec for KV cache loading
    load_spec: Optional[LoadSpec] = None
    # Whether this is the last chunk of the request
    is_last_chunk: Optional[bool] = None

    @staticmethod
    def from_request_tracker(
        tracker: RequestTracker,
        block_size: int,
        load_spec: Optional[LoadSpec] = None,
        skip_save: Optional[bool] = False,
        is_last_chunk: Optional[bool] = None,
        discard_partial_chunks: bool = True,
    ) -> Optional["ReqMeta"]:
        """Create the request metadata from a request tracker.

        Args:
            tracker (RequestTracker): the request tracker.
            block_size (int): the block size in vLLM.
            load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
            skip_save (bool): whether to skip the save operation.
            is_last_chunk (Optional[bool]): whether this is the last chunk
                of the request.
            discard_partial_chunks (bool): whether to discard partial chunks.

        Returns:
            the request metadata if we need to perform load/save
            operations, None otherwise.
        """
        input_token_ids = tracker.token_ids
        input_token_len = len(input_token_ids)

        # For the save operation: do not save if one of the following holds
        # 1. the tokens have already been saved before (num_saved_tokens > 0)
        # 2. the number of unsaved tokens has not reached the chunk boundary
        skip_leading_tokens = tracker.num_saved_tokens
        chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
                          block_size if discard_partial_chunks else 0)
        # Calculate the number of tokens to save based on the
        # discard_partial_chunks setting
        num_tokens_to_save = ((input_token_len // block_size * block_size)
                              if discard_partial_chunks else input_token_len)

        skip_save = skip_save or num_tokens_to_save < chunk_boundary
        if skip_save and load_spec is None:
            return None

        # If we need to save, update the number of saved tokens
        if not skip_save:
            tracker.num_saved_tokens = num_tokens_to_save
        save_spec = SaveSpec(skip_leading_tokens, not skip_save)

        # Calculate the token ids and slot mappings for load and save
        # OPTIMIZATION: pre-allocate the buffer for token ids and block ids
        token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]

        # For the load operation: check whether the request is scheduled
        # to load
        if load_spec is not None and load_spec.can_load:
            logger.debug(
                "Scheduled to load %d tokens for request %s",
                load_spec.mooncake_cached_tokens,
                tracker.req_id,
            )
        else:
            # Do not load if not in `can_load` state
            load_spec = None
        logger.debug(
            "request: %s, meta save spec: %s, meta load spec: %s",
            tracker.req_id, save_spec, load_spec)
        return ReqMeta(
            req_id=tracker.req_id,
            token_ids=token_ids,
            block_ids=tracker.allocated_block_ids,
            save_spec=save_spec,
            load_spec=load_spec,
            is_last_chunk=is_last_chunk,
        )


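# A worked sketch of the save-skip arithmetic with made-up numbers: for a
# tracker holding 200 tokens with nothing saved yet and block_size=128,
# num_tokens_to_save = 200 // 128 * 128 = 128 and the chunk boundary is
# cdiv(0 + 1, 128) * 128 = 128, so the request is saved and the trailing
# partial chunk is discarded; with only 100 tokens this would return None,
# because 0 < 128 and there is no load_spec.
def _example_req_meta() -> Optional["ReqMeta"]:
    tracker = RequestTracker(
        req_id="req-0",
        token_ids=list(range(200)),
        allocated_block_ids=[0, 1],
        num_saved_tokens=0,
    )
    return ReqMeta.from_request_tracker(tracker, block_size=128)

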
class MooncakeConnectorMetadata(KVConnectorMetadata):

    def __init__(self, unfinished_request_ids):
        self.requests = []
        self.unfinished_request_ids = unfinished_request_ids

    def add_request(self, req_meta: ReqMeta) -> None:
        """Add a request to the metadata.

        Args:
            req_meta (ReqMeta): the request metadata.
        """
        self.requests.append(req_meta)


@dataclass
class LasyerMultiBlockReqMeta:
    req_id: str
    keys: List[LayerMooncakeEngineKey]
    starts: List[int]
    ends: list[int]
    block_ids: list[int]
    layer_id: int


@dataclass
class MooncakeStoreConfig:
    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str
    use_ascend_direct: bool

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        with open(file_path) as file:
            config = json.load(file)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            # Defaults: 3355443200 bytes (3.125 GiB) global segment and
            # 1073741824 bytes (1 GiB) local buffer.
            global_segment_size=config.get("global_segment_size", 3355443200),
            local_buffer_size=config.get("local_buffer_size", 1073741824),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
            use_ascend_direct=config.get("use_ascend_direct", False))

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        config_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_path)
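

# A minimal loading sketch; the path, hostname, and server addresses below are
# illustrative placeholders, not values shipped with this plugin. Fields left
# out of the JSON fall back to the defaults in from_file().
def _example_store_config(
        path: str = "/tmp/mooncake.json") -> MooncakeStoreConfig:
    sample = {
        "local_hostname": "npu-node-0",
        "metadata_server": "127.0.0.1:2379",
        "master_server_address": "127.0.0.1:50051",
        "protocol": "tcp",
    }
    with open(path, "w") as f:
        json.dump(sample, f)
    os.environ["MOONCAKE_CONFIG_PATH"] = path
    return MooncakeStoreConfig.load_from_env()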