mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
大改
This commit is contained in:
639
vllm_npu/distributed/mooncake/mooncake_engine.py
Normal file
639
vllm_npu/distributed/mooncake/mooncake_engine.py
Normal file
@@ -0,0 +1,639 @@
|
||||
# Standard
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import get_kv_cache_torch_dtype, logger
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
|
||||
MooncakeEngineMetadata)
|
||||
from vllm_npu.distributed.mooncake.kv_transfer import (
|
||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class MooncakeEngine:
|
||||
#The main class for the cache engine.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
use_layerwize: bool,
|
||||
):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.use_mla = False
|
||||
if (hasattr(model_config, "use_mla")
|
||||
and isinstance(model_config.use_mla, bool)
|
||||
and model_config.use_mla):
|
||||
self.use_mla = True
|
||||
self.use_layerwise = use_layerwize
|
||||
self.tp_rank = parallel_config.rank
|
||||
self.tp_size = parallel_config.tensor_parallel_size
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"register_buffer", False)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.current_layer = 0
|
||||
# self.use_mla = first_kv_cache_tuple[0].size(
|
||||
# -1) != first_kv_cache_tuple[1].size(-1)
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
num_kv_head = model_config.get_num_kv_heads(parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
kv_dtype = get_kv_cache_torch_dtype(
|
||||
vllm_config.cache_config.cache_dtype, model_config.dtype)
|
||||
self.hidden_dim_size = num_kv_head * head_size
|
||||
if self.use_mla:
|
||||
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
|
||||
else:
|
||||
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
|
||||
head_size)
|
||||
self.metadata = MooncakeEngineMetadata(
|
||||
model_config.model,
|
||||
parallel_config.world_size,
|
||||
parallel_config.rank,
|
||||
kv_dtype,
|
||||
kv_shape,
|
||||
self.block_size,
|
||||
self.use_mla,
|
||||
)
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata)
|
||||
|
||||
self.m_store = Mooncakestore(parallel_config)
|
||||
|
||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
self.kv_caches_base_addr = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
self._register(base_addr, region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
self._register(base_addr, region_len)
|
||||
|
||||
if self.use_layerwise:
|
||||
self.get_event = threading.Event()
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending,
|
||||
self.num_layers)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database, self.block_len,
|
||||
self.block_size, ready_event, self.get_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def _register(self, ptr, length):
|
||||
logger.debug(
|
||||
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
|
||||
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
|
||||
try:
|
||||
self.m_store.register_buffer(ptr, length)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Mooncake memory registration failed. Error is: {e}")
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
self.current_layer = 0
|
||||
self.layerwise_retrievers = []
|
||||
for request in metadata.requests:
|
||||
load_spec = request.load_spec
|
||||
if load_spec is None or not load_spec.can_load: #load =0
|
||||
continue
|
||||
tokens = request.token_ids
|
||||
req_id = request.req_id
|
||||
if (load_spec.mooncake_cached_tokens % self.block_size
|
||||
!= 0) and (load_spec.mooncake_cached_tokens
|
||||
== tokens.shape[0] - 1):
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
|
||||
else:
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
|
||||
masked_token_count = (request.load_spec.vllm_cached_tokens //
|
||||
self.block_size * self.block_size)
|
||||
token_mask = torch.ones_like(tokens, dtype=torch.bool)
|
||||
token_mask[:masked_token_count] = False
|
||||
if self.use_layerwise:
|
||||
layerwise_retriever = self.retrieve_layer(
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
next(layerwise_retriever) # first layer load
|
||||
self.layerwise_retrievers.append(layerwise_retriever)
|
||||
else:
|
||||
if self.load_async:
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
else:
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
self.m_store.get_batch(key_list, addr_list, size_list,
|
||||
blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, _ = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
self.m_store.get(key, addr, size)
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
"""MooncakeConnector does not do layerwise saving."""
|
||||
for layerwise_retriever in self.layerwise_retrievers:
|
||||
ret_token_mask = next(layerwise_retriever)
|
||||
if self.current_layer == self.num_layers - 1:
|
||||
assert ret_token_mask is not None
|
||||
num_retrieved_tokens = ret_token_mask.sum().item()
|
||||
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
|
||||
|
||||
def save_kv_layer(self,
|
||||
connector_metadata: MooncakeConnectorMetadata) -> None:
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
if self.current_layer == 0:
|
||||
self.layerwise_storers = []
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
# TODO: whether need to remov saveThread
|
||||
# no lookup, skipmask
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
layerwise_storer = self.store_layer(
|
||||
req_id,
|
||||
token_ids,
|
||||
mask=store_mask,
|
||||
block_ids=request.block_ids,
|
||||
)
|
||||
self.layerwise_storers.append(layerwise_storer)
|
||||
for layerwise_storer in self.layerwise_storers:
|
||||
try:
|
||||
next(layerwise_storer)
|
||||
except Exception:
|
||||
raise
|
||||
self.current_layer = self.current_layer + 1
|
||||
|
||||
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
token_ids,
|
||||
request.block_ids,
|
||||
store_mask,
|
||||
request.is_last_chunk,
|
||||
current_event,
|
||||
)
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[Optional[torch.Tensor], None, None]:
|
||||
"""
|
||||
Retrieve the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the KV transfer which
|
||||
will be passed into the npu_transfer.
|
||||
|
||||
return: A generator that yields Optional[torch.Tensor]. The tensor will
|
||||
be the boolean mask indicating which tokens are retrieved and will
|
||||
only be returned in the last iteration.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_required_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_required_tokens = len(tokens)
|
||||
|
||||
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
first_flag = True
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer)
|
||||
ret_mask[start:end] = True
|
||||
|
||||
if keys:
|
||||
# Transpose the keys into layer major format
|
||||
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
if not first_flag:
|
||||
is_finish = self.get_event.wait(timeout=3) #try---cache
|
||||
if not is_finish:
|
||||
logger.info("Layerwise get failed")
|
||||
self.get_event.clear()
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
first_flag = False
|
||||
yield None
|
||||
else:
|
||||
# If no cache are found, we still need to yield to avoid
|
||||
# `StopIteration`
|
||||
for layer_id in range(self.num_layers):
|
||||
yield None
|
||||
|
||||
retrieved_tokens = torch.sum(ret_mask)
|
||||
logger.debug(f"Retrieved {retrieved_tokens} "
|
||||
f"out of {num_required_tokens} "
|
||||
f"out of total {len(tokens)} tokens")
|
||||
|
||||
yield ret_mask
|
||||
|
||||
def store_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Store the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the storage backend which
|
||||
will be passed into the gpu_connector.
|
||||
|
||||
return: A generator that yields None. In the first iteration, the
|
||||
generator allocates the memory objects for all layers and moves
|
||||
the KV cache of the first layer from GPU to CPU. In the next
|
||||
iterations, it moves the KV cache of layer i from GPU to the memory
|
||||
objects (on CPU) and puts the memory objects of layer i-1 to the
|
||||
storage backends. In the last iteration, it puts the memory objects
|
||||
of the last layer to the storage backends.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_stored_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_stored_tokens = len(tokens)
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer) #[block_num,layer_num]
|
||||
|
||||
if keys:
|
||||
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
yield
|
||||
else:
|
||||
for layer_id in range(self.num_layers):
|
||||
yield
|
||||
logger.debug(
|
||||
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.kv_send_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
|
||||
|
||||
done_recving = (
|
||||
self.kv_recv_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.load_async else set())
|
||||
|
||||
logger.debug(
|
||||
"Number of completed KV cache send requests: %d, receive "
|
||||
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
|
||||
self.tp_rank)
|
||||
return done_sending, done_recving
|
||||
|
||||
def wait_layer_transfer_finish(self):
|
||||
time.sleep(10)
|
||||
pass
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
res = self.m_store.batch_exists(
|
||||
keys) # type: ignore[assignment]
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return starts[index]
|
||||
# all tokens where found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def lookup_scheduler(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
multi_tp_keys = keys[:]
|
||||
for i in range(1, self.tp_size):
|
||||
for item in keys:
|
||||
new_str = item.replace( # type: ignore[attr-defined]
|
||||
"@0", f"@{i}", 1)
|
||||
multi_tp_keys.append(new_str)
|
||||
res = self.m_store.batch_exists(
|
||||
multi_tp_keys) # type: ignore[assignment]
|
||||
num_block = len(keys)
|
||||
multi_tp_values = [
|
||||
res[i * num_block:(i + 1) *
|
||||
num_block] # type: ignore[index]
|
||||
for i in range(self.tp_size)
|
||||
]
|
||||
index = self.find_min_first_non_one_index(multi_tp_values)
|
||||
if index != -1:
|
||||
return starts[index]
|
||||
# all tokens where found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def find_min_first_non_one_index(self, arr):
|
||||
try:
|
||||
return min(idx for row in arr for idx, val in enumerate(row)
|
||||
if val != 1)
|
||||
except ValueError:
|
||||
return -1
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the cache engine and free all the resources"""
|
||||
self.m_store.close()
|
||||
Reference in New Issue
Block a user