Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git
Synced 2026-02-20 19:50:15 +00:00
Major overhaul
vllm_npu/distributed/mooncake/kv_transfer.py (new file, 293 lines)
@@ -0,0 +1,293 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

import torch
from vllm.utils import logger

from vllm_npu.distributed.mooncake.config_data import (
    ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore


class KVTransferThread(threading.Thread):

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event, name: str):
        super().__init__(daemon=True, name=name)
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.m_store = m_store
        self.ready_event = ready_event
        self.kv_caches_base_addr = local_kv_caches_base_addr
        self.block_len = block_len
        self.token_database = token_database
        self.block_size = block_size
        self.done_task_lock = threading.Lock()
        # TODO(jianzs): find a better way to detect MLA.
        self.use_mla = len(block_len) == 2

        self.request_queue: queue.Queue[Any] = queue.Queue()
        # TODO(jianzs): make this configurable
        self.executor = ThreadPoolExecutor(max_workers=32)
        self.finished_requests: set[str] = set()

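    # Resolve, for every KV cache tensor, the device address and byte length
    # of the single block that covers tokens [start, end): each address is
    # base_addr + block_id * block_len, and each length is the per-token
    # slice of the block scaled by the width of the span.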
    def prepare_value(self, start: int, end: int, block_ids: list[int]):
        addr_list = []
        size_list = []
        block_id = block_ids[start // self.block_size]
        for index, base_addr in enumerate(self.kv_caches_base_addr):
            block_len = (self.block_len[index % 2]
                         if self.use_mla else self.block_len[0])

            addr = base_addr + block_id * block_len
            length = int(block_len / self.block_size * (end - start))
            addr_list.append(addr)
            size_list.append(length)
        return addr_list, size_list, block_id

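    # Layer-wise variant: resolve only layer `layer_id`, whose K and V
    # buffers occupy consecutive slots (2 * layer_id and 2 * layer_id + 1)
    # in kv_caches_base_addr; under MLA the two buffers differ in length.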
    def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
                            layer_id: int):
        block_id = block_ids[start // self.block_size]
        if self.use_mla:
            addr_k = (self.kv_caches_base_addr[layer_id * 2] +
                      block_id * self.block_len[0])
            addr_v = (self.kv_caches_base_addr[layer_id * 2 + 1] +
                      block_id * self.block_len[1])
            length_k = int(self.block_len[0] / self.block_size * (end - start))
            length_v = int(self.block_len[1] / self.block_size * (end - start))
            size_list = [length_k, length_v]
        else:
            addr_k = (self.kv_caches_base_addr[layer_id * 2] +
                      block_id * self.block_len[0])
            addr_v = (self.kv_caches_base_addr[layer_id * 2 + 1] +
                      block_id * self.block_len[0])
            length = int(self.block_len[0] / self.block_size * (end - start))
            size_list = [length, length]
        addr_list = [addr_k, addr_v]
        return addr_list, size_list

    def add_request(
        self,
        req_id: str,
        tokens: torch.Tensor,
        block_ids: list[int],
        mask: Optional[torch.Tensor] = None,
        is_last_chunk: Optional[bool] = None,
        current_event: Optional[torch.npu.Event] = None,
    ) -> None:
        req = {
            "req_id": req_id,
            "tokens": tokens,
            "block_ids": block_ids,
            "mask": mask,
            "is_last_chunk": is_last_chunk,
            "current_event": current_event,
        }
        self.request_queue.put(req)

    def get_and_clear_finished_requests(self) -> set[str]:
        """Get and clear the requests that have been completed.

        Returns:
            A set of request IDs that have been completed.
        """
        with self.done_task_lock:
            finished_requests = self.finished_requests.copy()
            self.finished_requests.clear()
            return finished_requests

    def set_finished_request(self, req_id):
        with self.done_task_lock:
            self.finished_requests.add(req_id)

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        self.ready_event.set()
        while True:
            try:
                request_data = self.request_queue.get()
                if request_data is None:
                    logger.warning("Received a None request!")
                    self.request_queue.task_done()
                    continue
                self._handle_request(request_data)
            except Exception as e:
                logger.error(f"Error in KVCacheTransferThread: {e}")

    def _handle_request(self, req_meta: dict[str, Any]):
        pass


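# The two store threads below move whole token chunks through the Mooncake
# store: the sending thread persists KV blocks after they are filled, and the
# receiving thread fetches previously cached blocks back into local memory.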
class KVCacheStoreSendingThread(KVTransferThread):

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheSendingThread")

    def _handle_request(self, req_meta: dict[str, Any]):
        tokens = req_meta["tokens"]
        mask = req_meta["mask"]
        block_ids = req_meta["block_ids"]
        req_id = req_meta["req_id"]
        is_last_chunk = req_meta["is_last_chunk"]
        current_event = req_meta["current_event"]
        if self.m_store.config.use_ascend_direct:
            addr_list = []
            size_list = []
            key_list = []
            blockIds = []
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, block_id = self.prepare_value(
                    start, end, block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)
                blockIds.append(block_id)
            if key_list:
                # Note: Due to a bug in ADXL, calling
                # current_event.synchronize() may occasionally hang. The fix
                # lands in CANN 8.5.RC1; until that release, you can manually
                # build the master branch of https://gitcode.com/cann/hixl
                # to work around the issue.
                if current_event is not None:
                    current_event.synchronize()
                self.m_store.put_batch(key_list, addr_list, size_list,
                                       blockIds)
        else:
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, _ = self.prepare_value(start, end, block_ids)
                if current_event is not None:
                    current_event.synchronize()
                self.m_store.put(key, addr, size)
        if is_last_chunk:
            self.set_finished_request(req_id)
        self.request_queue.task_done()


class KVCacheStoreRecvingThread(KVTransferThread):

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreRecvingThread")

    def _handle_request(self, req_meta: dict[str, Any]):
        tokens = req_meta["tokens"]
        mask = req_meta["mask"]
        block_ids = req_meta["block_ids"]
        req_id = req_meta["req_id"]
        if self.m_store.config.use_ascend_direct:
            addr_list = []
            size_list = []
            key_list = []
            blockIds = []
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, block_id = self.prepare_value(
                    start, end, block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)
                blockIds.append(block_id)
            self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
        else:
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, _ = self.prepare_value(start, end, block_ids)
                self.m_store.get(key, addr, size)
        self.set_finished_request(req_id)
        self.request_queue.task_done()


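# The layer-wise threads below move a single layer's KV blocks per request so
# transfers can overlap with model execution: the sender marks a request
# finished only once the final layer is written, and the receiver signals
# get_event after each layer has been fetched.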
class KVCacheStoreLayerSendingThread(KVTransferThread):

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event,
                 num_layers: int):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreLayerSendingThread")
        self.final_layer_id = num_layers - 1

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        for index, key in enumerate(req_meta.keys):
            addr, size = self.prepare_value_layer(req_meta.starts[index],
                                                  req_meta.ends[index],
                                                  req_meta.block_ids,
                                                  req_meta.layer_id)
            self.m_store.put(key, addr, size)
        if req_meta.layer_id == self.final_layer_id:
            self.set_finished_request(req_meta.req_id)
        self.request_queue.task_done()


class KVCacheStoreLayerRecvingThread(KVTransferThread):

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event,
                 get_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreLayerRecvingThread")
        self.get_event = get_event

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        for index, key in enumerate(req_meta.keys):
            addr, size = self.prepare_value_layer(req_meta.starts[index],
                                                  req_meta.ends[index],
                                                  req_meta.block_ids,
                                                  req_meta.layer_id)
            self.m_store.get(key, addr, size)
        self.request_queue.task_done()
        self.get_event.set()
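For orientation, a minimal sketch of how a connector might drive one of these
threads. The `store`, `token_db`, `kv_base_addrs`, `kv_block_bytes`, and
`prompt_tokens` names are illustrative placeholders, not part of this commit:

import threading

ready_event = threading.Event()
sender = KVCacheStoreSendingThread(
    tp_rank=0,
    tp_size=1,
    m_store=store,                             # a configured Mooncakestore
    local_kv_caches_base_addr=kv_base_addrs,   # one base address per KV tensor
    token_database=token_db,                   # a ChunkedTokenDatabase
    block_len=[kv_block_bytes],                # single entry -> non-MLA layout
    block_size=128,
    ready_event=ready_event)
sender.start()
ready_event.wait()  # run() sets this once the worker loop is live

# Enqueue a save; the worker drains the queue and writes to the store.
sender.add_request(req_id="req-0",
                   tokens=prompt_tokens,
                   block_ids=[3, 7],
                   is_last_chunk=True)

# Poll completion from the scheduler side.
finished = sender.get_and_clear_finished_requests()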