# Standard
import math
import threading
import time
from typing import Generator, List, Optional, Union

# Third Party
import torch
from vllm.config import VllmConfig
from vllm.utils import get_kv_cache_torch_dtype, logger

from vllm_npu.distributed.mooncake.config_data import (
    ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
    MooncakeEngineMetadata)
from vllm_npu.distributed.mooncake.kv_transfer import (
    KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
    KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore


class MooncakeEngine:
    """Cache engine that moves KV cache blocks to and from a Mooncake store.

    It owns a ``Mooncakestore`` client and a ``ChunkedTokenDatabase`` that
    maps token chunks to store keys. Once ``register_kv_caches`` has been
    called, it also owns the background send/receive threads. Transfers run
    either one layer at a time (``use_layerwize=True``) or for all layers
    at once.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        use_layerwize: bool,
    ):
        """
        :param VllmConfig vllm_config: Source of the model, parallel, cache
            and KV-transfer settings.
        :param bool use_layerwize: Transfer KV caches layer by layer instead
            of all layers at once (parameter name kept for compatibility).
        """
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        kv_transfer_config = vllm_config.kv_transfer_config
        extra_config = kv_transfer_config.kv_connector_extra_config

        # MLA is enabled only when the model config carries a real bool True.
        self.use_mla = getattr(model_config, "use_mla", False) is True
        self.use_layerwise = use_layerwize
        self.tp_rank = parallel_config.rank
        self.tp_size = parallel_config.tensor_parallel_size
        self.kv_role = kv_transfer_config.kv_role
        self.load_async = extra_config.get("load_async", False)
        self.register_buffer = extra_config.get("register_buffer", False)
        self.block_size = vllm_config.cache_config.block_size
        # Layer index used by the layerwise hooks. It is reset in
        # start_load_kv and advanced in save_kv_layer.
        self.current_layer = 0
        self.num_layers = model_config.get_num_layers(parallel_config)

        num_kv_head = model_config.get_num_kv_heads(parallel_config)
        head_size = model_config.get_head_size()
        kv_dtype = get_kv_cache_torch_dtype(
            vllm_config.cache_config.cache_dtype, model_config.dtype)
        self.hidden_dim_size = num_kv_head * head_size
        if self.use_mla:
            kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
        else:
            kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
                        head_size)
        self.metadata = MooncakeEngineMetadata(
            model_config.model,
            parallel_config.world_size,
            parallel_config.rank,
            kv_dtype,
            kv_shape,
            self.block_size,
            self.use_mla,
        )

        self.token_database = ChunkedTokenDatabase(self.metadata)
        self.m_store = Mooncakestore(parallel_config)
        self.kv_send_thread: Optional[KVTransferThread] = None
        self.kv_recv_thread: Optional[KVTransferThread] = None

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Record KV cache base addresses and start the transfer threads.

        Optionally registers each cache region with the Mooncake store first.

        :param kv_caches: Mapping from layer name to that layer's cache.
            Each value is indexed/iterated as a sequence of tensors: ``[0]``
            (and, for MLA, ``[1]``) are the tensors whose blocks are
            transferred.
        """
        _, first_kv_cache_tuple = next(iter(kv_caches.items()))
        first_kv_cache = first_kv_cache_tuple[0]

        # TODO(tms): Find a more robust way to detect and handle MLA
        block_rank = 3  # number of trailing dims that describe one block
        self.num_blocks = first_kv_cache.shape[0]
        if self.use_mla:
            # MLA case.[num_block, block_size, 1, hidden_dim]
            norm_cache = first_kv_cache_tuple[0]
            pe_cache = first_kv_cache_tuple[1]
            block_shape_norm = norm_cache.shape[-block_rank:]
            block_shape_pe = pe_cache.shape[-block_rank:]
            # Each tensor uses its own element size, since the two parts
            # are not guaranteed to share a dtype.
            self.block_len = [
                norm_cache.element_size() * math.prod(block_shape_norm),
                pe_cache.element_size() * math.prod(block_shape_pe),
            ]
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                self.num_blocks, block_shape_norm, block_shape_pe)
        else:
            # [num_block, block_size, num_head, hidden_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            self.block_len = [
                first_kv_cache.element_size() * math.prod(block_shape)
            ]
            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                        block_shape)

        logger.info("Registering KV_Caches. use_mla: %s, shape %s",
                    self.use_mla, first_kv_cache.shape)

        self.kv_caches = kv_caches
        self.kv_caches_base_addr = []
        for cache_or_caches in kv_caches.values():
            for i, cache in enumerate(cache_or_caches):
                base_addr = cache.data_ptr()
                self.kv_caches_base_addr.append(base_addr)
                if self.register_buffer:
                    # MLA alternates two block lengths within a layer; the
                    # other layouts use a single block length.
                    block_len = (self.block_len[i % 2]
                                 if self.use_mla else self.block_len[0])
                    self._register(base_addr, self.num_blocks * block_len)

        is_producer = self.kv_role in ['kv_producer', 'kv_both']
        if self.use_layerwise:
            self.get_event = threading.Event()
            if is_producer:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreLayerSendingThread(
                    self.tp_rank, self.tp_size, self.m_store,
                    self.kv_caches_base_addr, self.token_database,
                    self.block_len, self.block_size, ready_event_sending,
                    self.num_layers)
                self.kv_send_thread.start()
            # The layerwise receive thread is always started.
            ready_event = threading.Event()
            self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
                self.tp_rank, self.tp_size, self.m_store,
                self.kv_caches_base_addr, self.token_database, self.block_len,
                self.block_size, ready_event, self.get_event)
            self.kv_recv_thread.start()
            ready_event.wait()
        else:
            if is_producer:
                ready_event_sending = threading.Event()
                self.kv_send_thread = KVCacheStoreSendingThread(
                    self.tp_rank, self.tp_size, self.m_store,
                    self.kv_caches_base_addr, self.token_database,
                    self.block_len, self.block_size, ready_event_sending)
                self.kv_send_thread.start()
            if self.load_async:
                ready_event = threading.Event()
                self.kv_recv_thread = KVCacheStoreRecvingThread(
                    self.tp_rank, self.tp_size, self.m_store,
                    self.kv_caches_base_addr, self.token_database,
                    self.block_len, self.block_size, ready_event)
                self.kv_recv_thread.start()
                ready_event.wait()

    def _register(self, ptr, length):
        """Register the memory region ``[ptr, ptr + length)`` with the store.

        :raises RuntimeError: If the store rejects the registration; the
            original exception is chained as the cause.
        """
        logger.debug(
            "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
            "block_lens=%s", ptr, length, self.num_blocks, self.block_len)
        try:
            self.m_store.register_buffer(ptr, length)
        except Exception as e:
            raise RuntimeError(
                f"Mooncake memory registration failed. Error is: {e}") from e

    def start_load_kv(self, metadata: MooncakeConnectorMetadata):
        """Start loading cached KV for every request whose load spec allows it.

        Three modes:

        - Layerwise: a ``retrieve_layer`` generator is created per request
          and advanced once, which issues the first layer's load.
        - ``load_async``: the request is handed to the receive thread.
        - Otherwise: the chunks are fetched synchronously from the store.
        """
        self.current_layer = 0
        self.layerwise_retrievers = []
        for request in metadata.requests:
            load_spec = request.load_spec
            if load_spec is None or not load_spec.can_load:  # nothing to load
                continue
            tokens = request.token_ids
            req_id = request.req_id
            cached = load_spec.mooncake_cached_tokens
            # NOTE(review): when the cached count is not block aligned and
            # is exactly one short of the prompt, one extra token is kept.
            # The reason is not visible here; confirm with the scheduler.
            if (cached % self.block_size != 0) and (cached
                                                    == tokens.shape[0] - 1):
                tokens = tokens[:cached + 1]
            else:
                tokens = tokens[:cached]

            # Skip the block-aligned prefix that vLLM already holds locally.
            masked_token_count = (load_spec.vllm_cached_tokens //
                                  self.block_size * self.block_size)
            token_mask = torch.ones_like(tokens, dtype=torch.bool)
            token_mask[:masked_token_count] = False

            if self.use_layerwise:
                layerwise_retriever = self.retrieve_layer(
                    req_id,
                    tokens,
                    request.block_ids,
                    token_mask,
                )
                next(layerwise_retriever)  # first layer load
                self.layerwise_retrievers.append(layerwise_retriever)
            elif self.load_async:
                self.kv_recv_thread.add_request(  # type: ignore[union-attr]
                    req_id,
                    tokens,
                    request.block_ids,
                    token_mask,
                )
            else:
                self._load_sync(tokens, token_mask, request.block_ids)

    def _load_sync(self, tokens: torch.Tensor, token_mask: torch.Tensor,
                   block_ids: list[int]) -> None:
        """Synchronously fetch the masked chunks of ``tokens`` from the store."""
        if self.m_store.config.use_ascend_direct:
            # Gather every chunk, then issue a single batched get.
            addr_list = []
            size_list = []
            key_list = []
            block_id_list = []
            for start, end, key in self.token_database.process_tokens(
                    tokens, token_mask):
                addr, size, block_id = self.prepare_value(
                    start, end, block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)
                block_id_list.append(block_id)
            self.m_store.get_batch(key_list, addr_list, size_list,
                                   block_id_list)
        else:
            for start, end, key in self.token_database.process_tokens(
                    tokens, token_mask):
                addr, size, _ = self.prepare_value(start, end, block_ids)
                self.m_store.get(key, addr, size)

    def prepare_value(self, start: int, end: int, block_ids: list[int]):
        """Compute per-cache-tensor addresses and sizes for tokens [start, end).

        The chunk is assumed to start inside block ``block_ids[start //
        block_size]``.

        :return: ``(addr_list, size_list, block_id)``. There is one address
            and one byte size per registered base address, plus the block id
            that was used.
        """
        addr_list = []
        size_list = []
        block_id = block_ids[start // self.block_size]
        num_tokens = end - start
        for index, base_addr in enumerate(self.kv_caches_base_addr):
            block_len = (self.block_len[index % 2]
                         if self.use_mla else self.block_len[0])
            addr_list.append(base_addr + block_id * block_len)
            # Exact integer arithmetic; avoids float rounding on large sizes.
            size_list.append(block_len * num_tokens // self.block_size)
        return addr_list, size_list, block_id

    def wait_for_layer_load(self) -> None:
        """Advance every layerwise retriever by one layer.

        Each step waits for the previously issued layer (see
        ``retrieve_layer``) and issues the next one. When ``current_layer``
        is the last layer, the yielded retrieval masks are summed and logged.

        NOTE(review): ``current_layer`` is only advanced in ``save_kv_layer``,
        so the final-layer check depends on that hook being called per layer.
        """
        for layerwise_retriever in self.layerwise_retrievers:
            ret_token_mask = next(layerwise_retriever)
            if self.current_layer == self.num_layers - 1:
                assert ret_token_mask is not None
                num_retrieved_tokens = ret_token_mask.sum().item()
                logger.info("Retrieved %s tokens", num_retrieved_tokens)

    def _build_store_mask(self, request) -> Optional[torch.Tensor]:
        """Return the boolean mask of tokens to store for ``request``.

        Returns None when the request must be skipped. A leading prefix is
        excluded: the tokens already present in the store (per ``lookup``)
        or the save spec's ``skip_leading_tokens``, whichever is larger,
        rounded down to a block boundary. If nothing is left to store and
        this is the request's last chunk, the send thread is told the
        request has finished.
        """
        save_spec = request.save_spec
        if save_spec is None or not save_spec.can_save:
            return None

        token_ids = request.token_ids
        assert isinstance(token_ids, torch.Tensor)
        assert token_ids.is_cpu

        skip_leading_tokens = max(
            self.lookup(token_ids, self.use_layerwise),
            save_spec.skip_leading_tokens,
        )
        if skip_leading_tokens == len(token_ids):
            if request.is_last_chunk:
                self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
                    request.req_id)
            return None  # nothing new to store for this request

        skip_leading_tokens = (skip_leading_tokens // self.block_size *
                               self.block_size)
        store_mask = torch.ones_like(token_ids, dtype=torch.bool)
        store_mask[:skip_leading_tokens] = False

        logger.info(
            "Storing KV cache for %d out of %d tokens "
            "(skip_leading_tokens=%d) for request %s",
            len(token_ids) - skip_leading_tokens,
            len(token_ids),
            skip_leading_tokens,
            request.req_id,
        )
        return store_mask

    def save_kv_layer(self,
                      connector_metadata: MooncakeConnectorMetadata) -> None:
        """Enqueue the current layer's KV for every savable request.

        On the first layer (``current_layer == 0``) one ``store_layer``
        generator is created per savable request. Every call then advances
        all generators by one layer and increments ``current_layer``.

        NOTE(review): unlike ``wait_for_save``, no NPU event is handed to the
        send thread here; ``LasyerMultiBlockReqMeta`` carries none.
        """
        if self.current_layer == 0:
            self.layerwise_storers = []
            # TODO: decide whether the separate save thread is still needed.
            for request in connector_metadata.requests:
                store_mask = self._build_store_mask(request)
                if store_mask is None:
                    continue
                self.layerwise_storers.append(
                    self.store_layer(
                        request.req_id,
                        request.token_ids,
                        mask=store_mask,
                        block_ids=request.block_ids,
                    ))

        for layerwise_storer in self.layerwise_storers:
            next(layerwise_storer)
        self.current_layer += 1

    def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
        """Hand every savable request's KV to the send thread.

        This is the non-layerwise save path. If at least one request can
        save, a single NPU event is recorded on the current stream and
        passed with each request. Presumably the send thread waits on it
        before reading the KV blocks; confirm in KVCacheStoreSendingThread.
        """
        current_event = None
        if any(request.save_spec is not None and request.save_spec.can_save
               for request in connector_metadata.requests):
            current_event = torch.npu.Event()
            current_event.record()

        for request in connector_metadata.requests:
            store_mask = self._build_store_mask(request)
            if store_mask is None:
                continue
            self.kv_send_thread.add_request(  # type: ignore[union-attr]
                request.req_id,
                request.token_ids,
                request.block_ids,
                store_mask,
                request.is_last_chunk,
                current_event,
            )

    def retrieve_layer(
        self,
        req_id: str,
        tokens: torch.Tensor,
        block_ids: list[int],
        mask: Optional[torch.Tensor] = None,
    ) -> Generator[Optional[torch.Tensor], None, None]:
        """
        Retrieve the KV cache in a layerwise manner.

        :param str req_id: Id of the request being loaded.
        :param torch.Tensor tokens: The tokens of the corresponding KV caches.
        :param list[int] block_ids: Destination block ids for the request.
        :param Optional[torch.Tensor] mask: The mask for the tokens. It
            should have the same length as tokens and ALWAYS look like
            FFFFFTTTTTTT, where True means the token needs to be matched.

        return: A generator that yields ``None`` once per layer and then,
            on its final step, the boolean mask of tokens that were
            retrieved. Each per-layer step enqueues that layer's load on the
            receive thread.
        """
        if mask is not None:
            num_required_tokens = torch.sum(mask).item()
        else:
            num_required_tokens = len(tokens)

        ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")

        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                tokens, mask):
            keys_multi_layer = key.split_layers(self.num_layers)
            starts.append(start)
            ends.append(end)
            keys.append(keys_multi_layer)
            ret_mask[start:end] = True

        if keys:
            # Transpose the keys into layer major format: [num_layer, block_num]
            keys = [list(row) for row in zip(*keys)]
            for layer_id, keys_multi_chunk in enumerate(keys):
                if layer_id > 0:
                    # Wait up to 3 s for get_event, presumably set by the
                    # layer receive thread when the previous layer's load
                    # completes.
                    is_finish = self.get_event.wait(timeout=3)
                    if not is_finish:
                        logger.info("Layerwise get failed")
                self.get_event.clear()
                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
                                                   starts, ends, block_ids,
                                                   layer_id)
                self.kv_recv_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                yield None
            # NOTE(review): the load issued for the final layer is not waited
            # on before the mask is yielded below; confirm the caller or the
            # receive thread synchronizes it.
        else:
            # If no cache are found, we still need to yield to avoid
            # `StopIteration`
            for _ in range(self.num_layers):
                yield None

        retrieved_tokens = torch.sum(ret_mask).item()
        logger.debug("Retrieved %s out of %s out of total %s tokens",
                     retrieved_tokens, num_required_tokens, len(tokens))
        yield ret_mask

    def store_layer(
        self,
        req_id: str,
        tokens: torch.Tensor,
        block_ids: list[int],
        mask: Optional[torch.Tensor] = None,
    ) -> Generator[None, None, None]:
        """
        Store the KV cache in a layerwise manner.

        :param str req_id: Id of the request being stored.
        :param torch.Tensor tokens: The tokens of the corresponding KV caches.
        :param list[int] block_ids: Source block ids for the request.
        :param Optional[torch.Tensor] mask: The mask for the tokens. It
            should have the same length as tokens and ALWAYS look like
            FFFFFTTTTTTT, where True means the token needs to be stored.

        return: A generator that yields ``None`` once per layer. Each step
            enqueues one ``LasyerMultiBlockReqMeta`` covering all chunks of
            that layer on the send thread. If no chunks are produced, it
            still yields once per layer.
        """
        if mask is not None:
            num_stored_tokens = torch.sum(mask).item()
        else:
            num_stored_tokens = len(tokens)

        starts = []
        ends = []
        keys = []
        for start, end, key in self.token_database.process_tokens(
                tokens, mask):
            keys_multi_layer = key.split_layers(self.num_layers)
            starts.append(start)
            ends.append(end)
            keys.append(keys_multi_layer)  # [block_num, layer_num]

        if keys:
            keys = [list(row) for row in zip(*keys)]  # [layer_num, block_num]
            for layer_id, keys_multi_chunk in enumerate(keys):
                req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
                                                   starts, ends, block_ids,
                                                   layer_id)
                self.kv_send_thread.add_request(  # type: ignore[union-attr, call-arg]
                    req_meta)  # type: ignore[union-attr, call-arg, arg-type]
                yield
        else:
            for _ in range(self.num_layers):
                yield

        logger.debug("Stored %s out of total %s tokens", num_stored_tokens,
                     len(tokens))

    def get_finished(self) -> tuple[set[str], set[str]]:
        """Return ``(done_sending, done_recving)`` request-id sets.

        Each set comes from the corresponding thread's
        ``get_and_clear_finished_requests``. Sends are only polled for
        producer roles and receives only when ``load_async`` is set;
        otherwise an empty set is returned.
        """
        if self.kv_role in ['kv_producer', 'kv_both']:
            done_sending = self.kv_send_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
            )
        else:
            done_sending = set()
        if self.load_async:
            done_recving = self.kv_recv_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
            )
        else:
            done_recving = set()

        logger.debug(
            "Number of completed KV cache send requests: %d, receive "
            "requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
            self.tp_rank)
        return done_sending, done_recving

    def wait_layer_transfer_finish(self):
        """Block the caller for a fixed 10 seconds.

        NOTE(review): this does not inspect any transfer state; it only
        sleeps.
        """
        time.sleep(10)

    def _lookup_layerwise(self, tokens: Union[torch.Tensor,
                                              List[int]]) -> int:
        """Return how many prefix tokens are cached, one chunk at a time.

        A chunk counts as cached only when all of its per-layer keys exist.
        Each chunk's keys are queried on their own, so already-confirmed
        chunks are not re-queried. If the store raises, the prefix confirmed
        so far is returned.
        """
        confirmed = 0
        try:
            for start, end, key in self.token_database.process_tokens(tokens):
                layer_keys = [
                    item.to_string()
                    for item in key.split_layers(self.num_layers)
                ]
                exists = self.m_store.batch_exists(layer_keys)
                if any(value != 1 for value in exists):
                    return start
                confirmed = end
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
        return confirmed

    def lookup(
        self,
        tokens: Union[torch.Tensor, List[int]],
        use_layerwise: bool,
    ) -> int:
        """
        Checks the existence of KV cache of the tokens from the cache engine.

        :param tokens: the input tokens, with shape [seq_len]
        :param use_layerwise: check per-layer keys chunk by chunk instead of
            one batched query over the chunk keys.
        :return: An int indicating how many prefix tokens are cached. If the
            store cannot be reached, only the prefix confirmed before the
            failure is reported (0 for the batched query).
        """
        if use_layerwise:
            return self._lookup_layerwise(tokens)

        end = 0
        try:
            keys = []
            starts = []
            for start, end, key in self.token_database.process_tokens(
                    tokens):
                keys.append(key.to_string())
                starts.append(start)
            res = self.m_store.batch_exists(keys)  # type: ignore[assignment]
            for index, value in enumerate(res):  # type: ignore[arg-type]
                if value != 1:
                    return starts[index]
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
            # Nothing was confirmed before the failure; do not over-report.
            return 0
        # all tokens were found, return the maximal end
        return end

    def lookup_scheduler(
        self,
        tokens: Union[torch.Tensor, List[int]],
        use_layerwise: bool,
    ) -> int:
        """
        Checks the existence of KV cache of the tokens from the cache engine.

        Unlike ``lookup``, the batched path requires every tensor-parallel
        rank's copy of a chunk to exist.

        :param tokens: the input tokens, with shape [seq_len]
        :param use_layerwise: check per-layer keys chunk by chunk.
        :return: An int indicating how many prefix tokens are cached. If the
            store cannot be reached, only the prefix confirmed before the
            failure is reported (0 for the batched query).
        """
        if use_layerwise:
            return self._lookup_layerwise(tokens)

        end = 0
        try:
            keys = []
            starts = []
            for start, end, key in self.token_database.process_tokens(
                    tokens):
                keys.append(key.to_string())
                starts.append(start)

            # Derive every other TP rank's key by replacing the first "@0"
            # with "@<rank>". NOTE(review): assumes the key string encodes
            # the TP rank this way and that these keys were built for rank 0.
            multi_tp_keys = keys[:]
            for rank in range(1, self.tp_size):
                multi_tp_keys.extend(
                    item.replace(  # type: ignore[attr-defined]
                        "@0", f"@{rank}", 1) for item in keys)
            res = self.m_store.batch_exists(
                multi_tp_keys)  # type: ignore[assignment]

            # Split the flat result back into one row per TP rank.
            num_block = len(keys)
            multi_tp_values = [
                res[rank * num_block:(rank + 1) * num_block]  # type: ignore[index]
                for rank in range(self.tp_size)
            ]
            index = self.find_min_first_non_one_index(multi_tp_values)
            if index != -1:
                return starts[index]
        except Exception as e:
            logger.error("Remote connection failed in contains: %s", e)
            # Nothing was confirmed before the failure; do not over-report.
            return 0
        # all tokens were found, return the maximal end
        return end

    def find_min_first_non_one_index(self, arr):
        """Return the smallest column index holding a value != 1 in any row.

        Returns -1 if every value is 1 (or ``arr`` is empty).
        """
        return min(
            (idx for row in arr for idx, val in enumerate(row) if val != 1),
            default=-1)

    def close(self) -> None:
        """Close the cache engine and free all the resources"""
        self.m_store.close()