Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)

Commit: Major overhaul
@@ -1 +1,40 @@
"""Ascend NPU distributed communication (HCCL)."""
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm.distributed.kv_transfer.kv_connector.factory import \
    KVConnectorFactory


def register_connector():
    KVConnectorFactory.register_connector(
        "LLMDataDistCMgrConnector",
        "vllm_npu.distributed.llmdatadist_c_mgr_connector",
        "LLMDataDistCMgrConnector")

    KVConnectorFactory.register_connector(
        "MooncakeConnectorV1", "vllm_npu.distributed.mooncake_connector",
        "MooncakeConnector")

    KVConnectorFactory.register_connector(
        "MooncakeConnectorStoreV1",
        "vllm_npu.distributed.mooncake.mooncake_store_connector_v1",
        "MooncakeConnectorV1")

    KVConnectorFactory.register_connector(
        "MooncakeLayerwiseConnector",
        "vllm_npu.distributed.mooncake_layerwise_connector",
        "MooncakeLayerwiseConnector")
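
# --- Editor's example (not part of this commit) ---------------------------
# A minimal sketch of how one of the connectors registered above would be
# selected at serving time. It assumes vLLM's KVTransferConfig API; the model
# name and the kv_role value are placeholders, not taken from this commit.
def _example_select_connector():
    from vllm import LLM
    from vllm.config import KVTransferConfig

    return LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
        kv_transfer_config=KVTransferConfig(
            kv_connector="MooncakeConnectorV1",  # a name registered above
            kv_role="kv_both",
        ),
    )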
@@ -1,42 +1,46 @@
"""
NPUCommunicator — HCCL-based device communicator for Ascend NPU.

Extends ``DeviceCommunicatorBase`` with NPU-specific collective
operations using the HCCL backend.
"""
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, Optional

import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
    DeviceCommunicatorBase


class NPUCommunicator(DeviceCommunicatorBase):
    """Device communicator for Ascend NPU using HCCL."""

    def __init__(self,
                 cpu_group: dist.ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[dist.ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        import torch_npu  # noqa: F401
        # TODO(hz): Refer to CudaCommunicator's implementation to integrate
        # PyHcclCommunicator.
        # Init the device according to the rank.
        self.device = torch.npu.current_device()

    def all_to_all(self,
                   input_: torch.Tensor,
                   scatter_dim: int = 0,
                   gather_dim: int = -1,
                   scatter_sizes: Optional[List[int]] = None,
                   gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
        """All-to-all communication for NPU tensors."""
        if scatter_dim < 0:
            scatter_dim += input_.dim()
        if gather_dim < 0:
@@ -53,22 +57,17 @@ class NPUCommunicator(DeviceCommunicatorBase):
                tensor_shape = list(tensor_shape_base)
                tensor_shape[gather_dim] = gather_sizes[i]
                output_list.append(
                    torch.empty(tensor_shape,
                                dtype=input_.dtype,
                                device=input_.device))

        else:
            input_list = [
                t.contiguous() for t in torch.tensor_split(
                    input_, self.world_size, scatter_dim)
            ]
            output_list = [
                torch.empty_like(input_list[i]) for i in range(self.world_size)
            ]

        dist.all_to_all(output_list, input_list, group=self.device_group)
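
# --- Editor's example (not part of this commit) ---------------------------
# Sketch of what all_to_all does to tensor shapes on a 4-way group; `comm` is
# assumed to be an already constructed NPUCommunicator, and the shapes and
# device string are illustrative only.
def _example_all_to_all(comm: "NPUCommunicator") -> None:
    x = torch.randn(8, 16, device="npu:0")  # local shard on this rank
    # Scatter dim 0 across 4 ranks and gather dim 1 from all 4 ranks:
    # each rank ends up with an (8 // 4, 16 * 4) == (2, 64) tensor.
    y = comm.all_to_all(x, scatter_dim=0, gather_dim=1)
    assert y.shape == (2, 64)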
vllm_npu/distributed/cpu_offload_connector.py (new file, 471 lines)
@@ -0,0 +1,471 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
MLAAttentionSpec)
|
||||
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
from vllm_npu.distributed.cpu_offload_manager.metadata import (
|
||||
MetadataServer, MetadataServerProc, MLAConfig)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
gpu_block_ids: list[int]
|
||||
cpu_block_ids: list[int]
|
||||
num_scheduled_tokens: int
|
||||
num_computed_tokens: int
|
||||
num_gpu_computed_tokens: int
|
||||
num_cpu_computed_tokens: int
|
||||
|
||||
def update(self, other: "ReqMeta"):
|
||||
self.gpu_block_ids.extend(other.gpu_block_ids)
|
||||
self.cpu_block_ids.extend(other.cpu_block_ids)
|
||||
self.num_scheduled_tokens = other.num_scheduled_tokens
|
||||
self.num_computed_tokens = other.num_computed_tokens
|
||||
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
|
||||
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
requests: dict[str, ReqMeta]
|
||||
finished_req_ids: set[str]
|
||||
|
||||
|
||||
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
if not vllm_config.cache_config.enable_prefix_caching:
|
||||
self.connector_scheduler: Optional[
|
||||
CPUOffloadingConnectorScheduler] = None
|
||||
self.connector_worker: Optional[
|
||||
CPUOffloadingConnectorWorker] = None
|
||||
elif role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = CPUOffloadingConnectorScheduler(
|
||||
vllm_config)
|
||||
self.connector_worker = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
if self.connector_worker is not None:
|
||||
assert isinstance(connector_metadata,
|
||||
CPUOffloadingConnectorMetadata)
|
||||
self.connector_worker.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.clear_connector_metadata()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.start_load_kv()
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(), None
|
||||
|
||||
    # ==============================
    # Scheduler-side methods
    # ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.update_state_after_alloc(request)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.build_connector_meta(
|
||||
scheduler_output)
|
||||
return KVConnectorMetadata()
|
||||
|
||||
def request_finished(
|
||||
self, request: "Request",
|
||||
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
if self.connector_scheduler is not None:
|
||||
self.connector_scheduler.request_finished(request)
|
||||
return True, None
|
||||
|
||||
|
||||
class CPUOffloadingConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorScheduler")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
self.num_gpu_computed_tokens: dict[str, int] = {}
|
||||
self.num_cpu_computed_tokens: dict[str, int] = {}
|
||||
self.allocated_req_ids: set[str] = set()
|
||||
self.finished_req_ids: list[str] = []
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.zmq_rpc_client.call("post_init")
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"swap_in_threshold", 0)
|
||||
else:
|
||||
self.swap_in_threshold = 0
|
||||
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, ori_request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
|
||||
"get_matched_num_and_touch", request)
|
||||
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
||||
self.num_cpu_computed_tokens[
|
||||
request.request_id] = num_cpu_computed_tokens
|
||||
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
||||
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
||||
else:
|
||||
return 0, load_async
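    # Editor's note (worked example, hypothetical numbers): if the GPU prefix
    # cache already covers 256 tokens, the CPU cache covers 1024 tokens, and
    # swap_in_threshold = 512, then 1024 - 256 = 768 >= 512, so the method
    # above reports 768 extra matched tokens and those blocks are swapped in
    # from CPU memory; below the threshold it reports 0 and nothing is swapped.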
|
||||
|
||||
def update_state_after_alloc(self, request: "Request"):
|
||||
self.allocated_req_ids.add(request.request_id)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
num_tokens = {}
|
||||
# process scheduled_new_reqs
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
num_tokens[req_id] = (
|
||||
req.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
# process scheduled_cached_reqs
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
num_tokens[req_id] = (
|
||||
cached_reqs.num_computed_tokens[idx] +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
|
||||
self.allocated_req_ids -
|
||||
scheduler_output.num_scheduled_tokens.keys())
|
||||
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
|
||||
num_tokens,
|
||||
unallocated_req_ids)
|
||||
metadata = CPUOffloadingConnectorMetadata(
|
||||
requests={},
|
||||
finished_req_ids=set(self.finished_req_ids),
|
||||
)
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
gpu_block_ids = req.block_ids[0]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=req.num_computed_tokens,
|
||||
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
||||
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
|
||||
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
|
||||
self.num_gpu_computed_tokens.clear()
|
||||
self.num_cpu_computed_tokens.clear()
|
||||
self.allocated_req_ids.clear()
|
||||
self.finished_req_ids.clear()
|
||||
return metadata
|
||||
|
||||
def request_finished(self, ori_request: "Request"):
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
self.finished_req_ids.append(request.request_id)
|
||||
        # Inform the metadata server to record the request and free it after
        # sending finishes.
|
||||
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
|
||||
request)
|
||||
|
||||
|
||||
class CPUOffloadingConnectorWorker:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorWorker")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.pp_rank = get_pp_group().rank_in_group
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_rank = self.tp_group.rank_in_group
|
||||
self.tp_world_size = self.tp_group.world_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
self.load_stream = torch.npu.Stream()
|
||||
self.save_stream = torch.npu.Stream()
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.load_block_mapping: list[tuple[int, int]] = []
|
||||
self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
|
||||
self.save_output_queue: queue.Queue[str] = queue.Queue()
|
||||
self.save_thread = threading.Thread(target=self._save_listener)
|
||||
self.save_thread.start()
|
||||
self.done_sending_count: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
        # Start the metadata server to init the cpu_kv_cache_manager and
        # handle RPC requests. All DP ranks share the same metadata server,
        # so only start the process on data parallel rank 0.
|
||||
if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
|
||||
config = VllmConfig()
|
||||
config.cache_config = vllm_config.cache_config
|
||||
config.parallel_config = vllm_config.parallel_config
|
||||
config.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.init_metadata_server(config)
|
||||
self._wait_for_metadata_process_start()
|
||||
|
||||
def init_metadata_server(self, vllm_config: VllmConfig):
|
||||
self.metadata_thread = threading.Thread(
|
||||
target=MetadataServerProc.run_metadata_server,
|
||||
args=(vllm_config, ),
|
||||
)
|
||||
self.metadata_thread.daemon = True
|
||||
self.metadata_thread.start()
|
||||
|
||||
def _wait_for_metadata_process_start(self):
|
||||
# TODO: wait for metadata server to start, add a rpc to check if ready
|
||||
while True:
|
||||
try:
|
||||
if self.zmq_rpc_client.call("ready"):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.info(f"wait for metadata server to start, error: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
|
||||
for req_id, req in connector_metadata.requests.items():
|
||||
if req_id in self.requests:
|
||||
self.requests[req_id].update(req)
|
||||
req = self.requests[req_id]
|
||||
else:
|
||||
self.requests[req_id] = req
|
||||
for i in range(req.num_gpu_computed_tokens // self.block_size,
|
||||
req.num_computed_tokens // self.block_size):
|
||||
self.load_block_mapping.append(
|
||||
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
|
||||
for req_id in connector_metadata.finished_req_ids:
|
||||
if req_id in self.requests:
|
||||
self.save_input_queue.put((req_id, self.requests[req_id]))
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
self.load_block_mapping.clear()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
|
||||
self.gpu_kv_caches = kv_caches
|
||||
model_config = self.vllm_config.model_config
|
||||
mla_config: Optional[MLAConfig] = None
|
||||
if model_config.use_mla:
|
||||
mla_config = MLAConfig(
|
||||
model_config.hf_text_config.kv_lora_rank,
|
||||
model_config.hf_text_config.qk_rope_head_dim)
|
||||
self.cpu_kv_caches = list(
|
||||
self.zmq_rpc_client.call(
|
||||
"init_cpu_kv_caches",
|
||||
self.pp_rank,
|
||||
self.tp_rank,
|
||||
get_kv_cache_spec(self.vllm_config),
|
||||
mla_config,
|
||||
).values())
|
||||
|
||||
def start_load_kv(self) -> None:
|
||||
self.current_layer = 0
|
||||
self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
|
||||
self.load_kv_layer(0)
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
# TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
|
||||
self.load_stream.synchronize()
|
||||
self.current_layer += 1
|
||||
self.load_kv_layer(self.current_layer)
|
||||
|
||||
def load_kv_layer(self, layer: int):
|
||||
if layer == len(self.gpu_kv_caches):
|
||||
return
|
||||
gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
|
||||
cpu_kv_caches = self.cpu_kv_caches[layer]
|
||||
with torch.npu.stream(self.load_stream):
|
||||
for cpu_block_id, gpu_block_id in self.load_block_mapping:
|
||||
for gpu_layer_part, cpu_layer_part in zip(
|
||||
gpu_kv_caches, cpu_kv_caches):
|
||||
gpu_layer_part[gpu_block_id].copy_(
|
||||
cpu_layer_part[cpu_block_id], non_blocking=True)
|
||||
|
||||
def get_finished(self) -> set[str]:
|
||||
done_sending: set[str] = set()
|
||||
while True:
|
||||
try:
|
||||
id = self.save_output_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
done_sending.add(id)
|
||||
for id in done_sending:
|
||||
del self.requests[id]
|
||||
if self.tp_world_size == 1:
|
||||
return done_sending
|
||||
if self.tp_rank == 0:
|
||||
for req_id in done_sending:
|
||||
self.done_sending_count[req_id] += 1
|
||||
other_ranks_finished_ids: list[str] = []
|
||||
for i in range(1, self.tp_world_size):
|
||||
other_ranks_finished_ids.extend(
|
||||
self.tp_group.recv_object(src=i))
|
||||
for req_id in other_ranks_finished_ids:
|
||||
self.done_sending_count[req_id] += 1
|
||||
all_done_sending: set[str] = set()
|
||||
for req_id in list(self.done_sending_count.keys()):
|
||||
if self.done_sending_count[req_id] == self.tp_world_size:
|
||||
del self.done_sending_count[req_id]
|
||||
all_done_sending.add(req_id)
|
||||
            # Release the cpu_kv_cache after the request finishes sending.
            # To avoid blocking on the RPC, call it asynchronously from a thread.
|
||||
sending_finished_thread = threading.Thread(
|
||||
target=self._sending_finished, args=(all_done_sending, ))
|
||||
sending_finished_thread.daemon = True
|
||||
sending_finished_thread.start()
|
||||
|
||||
return all_done_sending
|
||||
else:
|
||||
self.tp_group.send_object(done_sending, dst=0)
|
||||
return done_sending
|
||||
|
||||
def _sending_finished(self, all_done_sending):
|
||||
for req_id in all_done_sending:
|
||||
logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
|
||||
self.zmq_rpc_client.call("cache_and_free_slots", req_id)
|
||||
|
||||
def _save_listener(self):
|
||||
save_block_mapping = []
|
||||
while True:
|
||||
req_id, req = self.save_input_queue.get()
|
||||
for i in range(
|
||||
req.num_cpu_computed_tokens // self.block_size,
|
||||
min((req.num_computed_tokens + req.num_scheduled_tokens) //
|
||||
self.block_size, len(req.cpu_block_ids))):
|
||||
save_block_mapping.append(
|
||||
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
|
||||
with torch.npu.stream(self.save_stream):
|
||||
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
|
||||
# non-MLA: kv_layer is list[tensor], typically means [k, v].
|
||||
if self.use_mla:
|
||||
start, step = self.tp_rank, self.tp_world_size
|
||||
else:
|
||||
start, step = 0, 1
|
||||
for i in range(start, len(save_block_mapping), step):
|
||||
gpu_block_id, cpu_block_id = save_block_mapping[i]
|
||||
for cpu_kv_caches, gpu_kv_caches in zip(
|
||||
self.cpu_kv_caches, self.gpu_kv_caches.values()):
|
||||
for cpu_layer_part, gpu_layer_part in zip(
|
||||
cpu_kv_caches, gpu_kv_caches):
|
||||
cpu_layer_part[cpu_block_id].copy_(
|
||||
gpu_layer_part[gpu_block_id],
|
||||
non_blocking=True)
|
||||
self.save_stream.synchronize()
|
||||
self.save_output_queue.put(req_id)
|
||||
save_block_mapping.clear()
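
# --- Editor's example (not part of this commit) ---------------------------
# Small sketch of how the MLA branch above stripes the save work across
# tensor-parallel ranks: with MLA the KV cache is replicated across TP ranks,
# so each rank copies only every tp_world_size-th block, starting at tp_rank.
def _blocks_saved_by_rank(num_blocks: int, tp_rank: int,
                          tp_world_size: int) -> list[int]:
    return list(range(tp_rank, num_blocks, tp_world_size))

# e.g. 10 blocks and 4 TP ranks -> rank 1 copies blocks [1, 5, 9]
assert _blocks_saved_by_rank(10, 1, 4) == [1, 5, 9]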
|
||||
|
||||
|
||||
# Copied from vllm_npu/worker/model_runner_v1.py.
|
||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
forward_ctx = vllm_config.compilation_config.static_forward_context
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
ascend_config = get_ascend_config()
|
||||
use_sfa = ascend_config.use_sfa
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
if isinstance(attn_module, FusedMoE):
|
||||
continue
|
||||
assert isinstance(attn_module, Attention)
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if use_mla and not use_sfa:
|
||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
cache_dtype_str=vllm_config.cache_config.cache_dtype)
|
||||
else:
|
||||
                # TODO(cmq): This is a hacky way to fix the DeepSeek KV cache
                # when using DSA. Fixing the spec in vLLM is the final way.
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype)
|
||||
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
continue
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
return kv_cache_spec
|
||||
vllm_npu/distributed/cpu_offload_manager/cpu_kv_cache_manager.py (new file, 202 lines)
@@ -0,0 +1,202 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from vllm.utils import logger, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
PrefixCachingMetrics)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import \
|
||||
get_manager_for_kv_cache_spec
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class CPUCacheStats:
|
||||
|
||||
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.log_stats = log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
|
||||
self.time_sec = int(time.time())
|
||||
|
||||
def log(self):
|
||||
current_time_sec = int(time.time())
|
||||
# Log the prefix cache hit rate every 10 seconds.
|
||||
if current_time_sec - self.time_sec >= 10:
|
||||
self.time_sec = current_time_sec
|
||||
logger.info("CPU Prefix cache hit rate: %.1f%%",
|
||||
self.cpu_prefix_cache_metrics.hit_rate * 100)
|
||||
|
||||
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
||||
"""Get (and reset) the prefix cache stats.
|
||||
Returns:
|
||||
The current prefix caching stats, or None if logging is disabled.
|
||||
"""
|
||||
if not self.log_stats:
|
||||
return None
|
||||
stats = self.prefix_cache_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def update(self, num_tokens, num_computed_tokens):
|
||||
# Note the function is called by scheduler
|
||||
if self.log_stats and self.enable_prefix_caching:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
self.prefix_cache_stats.queries += num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
|
||||
def set_cache_stats(self, num_tokens, num_computed_tokens):
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.hits = num_computed_tokens
|
||||
self.prefix_cache_stats.queries = num_tokens
|
||||
self.prefix_cache_stats.requests = 1
|
||||
|
||||
|
||||
class CPUKVCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
num_cpu_blocks: int,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.use_eagle = use_eagle
|
||||
self.block_pool = BlockPool(self.num_cpu_blocks, True,
|
||||
enable_kv_cache_events)
|
||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
# Record kv block hashes, avoid redundant computation.
|
||||
self.req_to_block_hashes: defaultdict[
|
||||
str, list[BlockHash]] = defaultdict(list)
|
||||
# Record blocks touched in get_matched_num_and_touch().
|
||||
self.req_to_computed_blocks: defaultdict[
|
||||
str, list[KVCacheBlock]] = defaultdict(list)
|
||||
# Record the request that failed to allocate.
|
||||
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
||||
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
||||
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
|
||||
log_stats=True)
|
||||
        # Record requests to be freed after sending finishes.
|
||||
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
||||
|
||||
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (request.sampling_params.prompt_logprobs is not None):
|
||||
return 0, False
|
||||
request_id = request.request_id
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request_id]
|
||||
if not block_hashes:
|
||||
block_hashes = request.block_hashes
|
||||
self.req_to_block_hashes[request_id] = block_hashes
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.single_type_manager.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
)
|
||||
num_computed_tokens = len(computed_blocks[0]) * self.block_size
|
||||
self.req_to_computed_blocks[request_id] = computed_blocks[0]
|
||||
# We should touch these blocks in the concurrent scenarios.
|
||||
self.block_pool.touch(computed_blocks)
|
||||
|
||||
        # Set and log the CPU prefix cache stats.
|
||||
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
||||
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
|
||||
num_computed_tokens)
|
||||
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
|
||||
self.cpu_cache_stats.prefix_cache_stats)
|
||||
self.cpu_cache_stats.log()
|
||||
|
||||
return num_computed_tokens, False
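    # Editor's note (worked example, hypothetical numbers): with
    # block_size = 128 and a 500-token prompt, max_cache_hit_length is 499,
    # so at most 499 // 128 = 3 full blocks (384 tokens) can be reported as
    # computed, which guarantees at least one token is left to compute.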
|
||||
|
||||
def _release_ahead_touch(self, request_id: str):
|
||||
computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
if computed_blocks:
|
||||
self.single_type_manager.block_pool.free_blocks(
|
||||
reversed(computed_blocks))
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
|
||||
def allocate_slots(self, req_to_num_tokens: dict[str, int],
|
||||
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
||||
for request_id in unallocated_req_ids:
|
||||
self._free_slots(request_id)
|
||||
req_to_new_blocks = {}
|
||||
for request_id, num_tokens in req_to_num_tokens.items():
|
||||
if self.req_failed_to_allocate[request_id]:
|
||||
continue
|
||||
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
num_blocks_to_allocate = (
|
||||
self.single_type_manager.get_num_blocks_to_allocate(
|
||||
request_id=request_id,
|
||||
num_tokens=num_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
))
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
self._release_ahead_touch(request_id)
|
||||
self.req_failed_to_allocate[request_id] = True
|
||||
continue
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.single_type_manager.save_new_computed_blocks(
|
||||
request_id, new_computed_blocks)
|
||||
# Allocate new blocks but do not cache now.
|
||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||
request_id, num_tokens)
|
||||
self.req_to_num_tokens[request_id] = num_tokens
|
||||
            # No need to release the ref_cnt here: these blocks are now
            # officially owned by the request.
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
req_to_new_blocks[request_id] = [
|
||||
block.block_id for block in new_computed_blocks + new_blocks
|
||||
]
|
||||
return req_to_new_blocks
|
||||
|
||||
def record_request_cache_and_free_slots(self, request: Request):
|
||||
logger.debug(
|
||||
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
self.req_to_free[request.request_id] = request
|
||||
|
||||
def cache_and_free_slots(self, request_id: str):
|
||||
logger.debug(
|
||||
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
if request_id not in self.req_to_free:
|
||||
            logger.error(
                f"request {request_id} not in req_to_free, maybe a bug!")
|
||||
return
|
||||
request = self.req_to_free[request_id]
|
||||
if not self.req_failed_to_allocate[request_id]:
|
||||
self.single_type_manager.cache_blocks(
|
||||
request,
|
||||
self.req_to_num_tokens[request_id],
|
||||
)
|
||||
self._free_slots(request_id)
|
||||
logger.debug(
|
||||
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
||||
del self.req_to_free[request_id]
|
||||
|
||||
def _free_slots(self, request_id: str):
|
||||
# This function is designed to be reentrant.
|
||||
self._release_ahead_touch(request_id)
|
||||
self.single_type_manager.free(request_id)
|
||||
self.req_to_block_hashes.pop(request_id, None)
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
self.req_failed_to_allocate.pop(request_id, None)
|
||||
self.req_to_num_tokens.pop(request_id, None)
|
||||
vllm_npu/distributed/cpu_offload_manager/metadata.py (new file, 269 lines)
@@ -0,0 +1,269 @@
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.utils import get_dtype_size, logger, make_zmq_socket
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_npu.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
|
||||
CPUKVCacheManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAConfig:
|
||||
nope_dim: int
|
||||
rope_dim: int
|
||||
|
||||
|
||||
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
elif kv_transfer_config.kv_connector == "MultiConnector":
|
||||
ktcs = kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors")
|
||||
for ktc in ktcs:
|
||||
kv_transfer_config = KVTransferConfig(**ktc)
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
return None
|
||||
|
||||
|
||||
class MetadataServer:
|
||||
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
|
||||
DEFAULT_CPU_SWAP_SPACE_GB = 800
|
||||
|
||||
class ZMQRPCClient:
|
||||
|
||||
def __init__(self, identity=f"worker-{os.getpid()}"):
|
||||
logger.info(f"metadata client for worker {identity} started")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.DEALER, # type: ignore
|
||||
bind=False,
|
||||
identity=identity.encode(),
|
||||
linger=0)
|
||||
|
||||
def call(self, func_name: str, *args, **kwargs) -> Any:
|
||||
request = (func_name, args, kwargs)
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(request))
|
||||
_ = self.socket.recv()
|
||||
response = pickle.loads(self.socket.recv())
|
||||
result, error = response
|
||||
if error:
|
||||
logger.exception(f"call metadata sever error: {error}")
|
||||
raise error
|
||||
if func_name == "init_cpu_kv_caches":
|
||||
(memory_dict, layer_size, layer_dtype, mla_config) = result
|
||||
# shared_memory_dict is recorded in self to close
|
||||
self.shared_memory_dict = memory_dict
|
||||
result = {}
|
||||
for key, shm in memory_dict.items():
|
||||
tensor = torch.frombuffer(
|
||||
shm.buf, dtype=layer_dtype).reshape(layer_size)
|
||||
if mla_config is not None:
|
||||
tensor = tensor.split(
|
||||
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
|
||||
result[key] = tensor
|
||||
return result
|
||||
|
||||
def __del__(self):
|
||||
# will be finalized by outer process
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
if hasattr(self, 'shared_memory_dict'):
|
||||
for shm in self.shared_memory_dict.values():
|
||||
shm.close()
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||
kv_transfer_config = get_cpu_offload_connector(vllm_config)
|
||||
assert kv_transfer_config is not None
|
||||
available_memory_gb = kv_transfer_config.get_from_extra_config(
|
||||
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
|
||||
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
|
||||
logger.info(f"cpu swap space: {self.available_memory} bytes")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.ROUTER, # type: ignore
|
||||
bind=True,
|
||||
linger=0)
|
||||
self.functions: dict[str, Callable] = {
|
||||
"init_cpu_kv_caches": self.init_cpu_kv_caches,
|
||||
"post_init": self.post_init,
|
||||
"ready": self.ready,
|
||||
}
|
||||
self.shared_memory = {} # type: ignore
|
||||
self.num_cpu_blocks = -1
|
||||
|
||||
@staticmethod
|
||||
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
|
||||
try:
|
||||
existing_shm = SharedMemory(name=name, create=False)
|
||||
existing_shm.close()
|
||||
existing_shm.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return SharedMemory(name=name, create=True, size=size)
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
def init_cpu_kv_caches(
|
||||
self,
|
||||
pp_rank: int,
|
||||
tp_rank: int,
|
||||
kv_cache_specs: dict[str, AttentionSpec],
|
||||
mla_config: MLAConfig,
|
||||
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
|
||||
MLAConfig]:
|
||||
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
|
||||
# follow the assumption that each layer has the same spec
|
||||
layer = next(iter(kv_cache_specs.values()))
|
||||
assert all([
|
||||
            layer.page_size_bytes == spec.page_size_bytes
            for spec in kv_cache_specs.values()
|
||||
])
|
||||
# mla shares the same kv cache among different tp
|
||||
if layer.use_mla:
|
||||
tp_rank = 0
|
||||
if (pp_rank, tp_rank) in self.shared_memory:
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
|
||||
available_memory = self.available_memory
|
||||
shared_memory_dict = {}
|
||||
if layer.use_mla:
|
||||
available_memory //= self.pipeline_parallel_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
else:
|
||||
available_memory //= self.world_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
|
||||
for layer_name in kv_cache_specs.keys():
|
||||
            # Only SharedMemory objects in this form can be shared over ZeroMQ + pickle.
|
||||
shared_memory_dict[
|
||||
layer_name] = MetadataServer._safe_create_shared_memory(
|
||||
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
||||
if layer.use_mla:
|
||||
assert mla_config is not None
|
||||
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, mla_config)
|
||||
else:
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, None)
|
||||
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
|
||||
self.num_cpu_blocks = num_blocks
|
||||
self.layer = layer
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
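    # Editor's note (worked sizing example, hypothetical values): with the
    # default 800 GiB of CPU swap space, world_size = 8, 32 non-MLA layers,
    # and page_size_bytes = 2 * 128 * 8 * 128 * 2 bytes (fp16, block_size 128,
    # 8 KV heads, head_size 128), each layer gets 800 GiB / 8 / 32 =
    # 3_355_443_200 bytes, i.e. 3_355_443_200 // 524_288 = 6400 CPU blocks.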
|
||||
|
||||
def post_init(self):
|
||||
        # Different data-parallel processes may call this multiple times.
|
||||
if hasattr(self, 'cpu_block_manager'):
|
||||
return
|
||||
# do shared_memory() at least once
|
||||
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
|
||||
assert self.num_cpu_blocks >= 0
|
||||
self.cpu_block_manager = CPUKVCacheManager(self.layer,
|
||||
self.num_cpu_blocks)
|
||||
self.functions.update({
|
||||
"get_matched_num_and_touch":
|
||||
self.cpu_block_manager.get_matched_num_and_touch,
|
||||
"allocate_slots":
|
||||
self.cpu_block_manager.allocate_slots,
|
||||
"record_request_cache_and_free_slots":
|
||||
self.cpu_block_manager.record_request_cache_and_free_slots,
|
||||
"cache_and_free_slots":
|
||||
self.cpu_block_manager.cache_and_free_slots,
|
||||
})
|
||||
|
||||
def serve_step(self):
|
||||
client_id = self.socket.recv()
|
||||
_ = self.socket.recv()
|
||||
raw_msg = self.socket.recv()
|
||||
try:
|
||||
func_name, args, kwargs = pickle.loads(raw_msg)
|
||||
except Exception as e:
|
||||
response = (None, Exception(f"Invalid request: {str(e)}"))
|
||||
else:
|
||||
if func_name in self.functions:
|
||||
try:
|
||||
result = self.functions[func_name](*args, **kwargs)
|
||||
response = (result, None) # type: ignore
|
||||
except Exception as e:
|
||||
logger.exception(f"metadata execute error: {e}")
|
||||
response = (None, e) # type: ignore
|
||||
else:
|
||||
response = (None, NameError(f"Function {func_name} not found"))
|
||||
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(response))
|
||||
|
||||
def shutdown(self):
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
|
||||
"ipc://", "")
|
||||
if os.path.exists(socket_path):
|
||||
os.remove(socket_path)
|
||||
for cached in self.shared_memory.values():
|
||||
for shm in cached[0].values():
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
|
||||
class MetadataServerProc:
|
||||
|
||||
@staticmethod
|
||||
def run_metadata_server(vllm_config: VllmConfig):
|
||||
if (not vllm_config.cache_config.enable_prefix_caching
|
||||
or get_cpu_offload_connector(vllm_config) is None):
|
||||
return
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
def _signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the worker
|
||||
# signal.signal(signal.SIGTERM, _signal_handler)
|
||||
# signal.signal(signal.SIGINT, _signal_handler)
|
||||
metadata_server: Optional[MetadataServer] = None
|
||||
try:
|
||||
metadata_server = MetadataServer(vllm_config)
|
||||
logger.info("Metadata server started.")
|
||||
while True:
|
||||
metadata_server.serve_step()
|
||||
except SystemExit:
|
||||
logger.info("Metadata server exiting.")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Metadata server error: {e}.")
|
||||
raise e
|
||||
finally:
|
||||
if metadata_server is not None:
|
||||
metadata_server.shutdown()
|
||||
vllm_npu/distributed/device_communicators/pyhccl.py (new file, 165 lines)
@@ -0,0 +1,165 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_npu.distributed.device_communicators.pyhccl_wrapper import (
|
||||
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
|
||||
hcclRedOpTypeEnum, hcclUniqueId)
|
||||
from vllm_npu.utils import current_stream
|
||||
|
||||
|
||||
class PyHcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the PyHcclCommunicator to. If None,
|
||||
                it will be bound to f"npu:{local_rank}".
|
||||
library_path: the path to the HCCL library. If None, it will
|
||||
use the default library path.
|
||||
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
|
||||
"""
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.HCCL, (
|
||||
"PyHcclCommunicator should be attached to a non-HCCL group.")
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
else:
|
||||
self.rank = group.rank
|
||||
self.world_size = group.world_size
|
||||
|
||||
self.group = group
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
try:
|
||||
self.hccl = HCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing HCCL library
|
||||
# e.g. in a non-NPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("vLLM is using pyhccl")
|
||||
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"npu:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
if self.rank == 0:
|
||||
# get the unique id from HCCL
|
||||
with torch.npu.device(device):
|
||||
self.unique_id = self.hccl.hcclGetUniqueId()
|
||||
else:
|
||||
# construct an empty unique id
|
||||
self.unique_id = hcclUniqueId()
|
||||
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
else:
|
||||
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||
|
||||
# hccl communicator and stream will use this device
|
||||
# `torch.npu.device` is a context manager that changes the
|
||||
# current npu device to the specified one
|
||||
with torch.npu.device(device):
|
||||
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
data = torch.zeros(1, device=device)
|
||||
self.all_reduce(data)
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# hccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
hcclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
aclrtStream_t(stream.npu_stream))
|
||||
return out_tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this hccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
else:
|
||||
buffer = buffer_type(tensor.data_ptr())
|
||||
self.hccl.hcclBroadcast(buffer, tensor.numel(),
|
||||
hcclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, aclrtStream_t(stream.npu_stream))
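
# --- Editor's example (not part of this commit) ---------------------------
# Hypothetical usage sketch: `cpu_group` must be a non-HCCL (e.g. gloo-backed)
# process group and `local_rank` the NPU index; the names are placeholders.
def _example_pyhccl_all_reduce(cpu_group: ProcessGroup,
                               local_rank: int) -> Optional[torch.Tensor]:
    comm = PyHcclCommunicator(group=cpu_group, device=local_rank)
    x = torch.ones(4, device=f"npu:{local_rank}")
    return comm.all_reduce(x)  # a new reduced tensor, or None when disabled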
|
||||
vllm_npu/distributed/device_communicators/pyhccl_wrapper.py (new file, 253 lines)
@@ -0,0 +1,253 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_npu.utils import find_hccl_library
|
||||
|
||||
# export types and functions from hccl to Python ===
|
||||
# for the original hccl definition, please check
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
|
||||
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
|
||||
|
||||
hcclResult_t = ctypes.c_int
|
||||
hcclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class hcclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 4108)]
|
||||
|
||||
|
||||
aclrtStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
hcclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class hcclDataTypeEnum:
|
||||
hcclInt8 = 0
|
||||
hcclInt16 = 1
|
||||
hcclInt32 = 2
|
||||
hcclFloat16 = 3
|
||||
hcclFloat32 = 4
|
||||
hcclInt64 = 5
|
||||
hcclUint64 = 6
|
||||
hcclUint8 = 7
|
||||
hcclUint16 = 8
|
||||
hcclUint32 = 9
|
||||
hcclFloat64 = 10
|
||||
hcclBfloat16 = 11
|
||||
hcclInt128 = 12
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.hcclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.hcclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.hcclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.hcclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.hcclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.hcclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.hcclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.hcclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
hcclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class hcclRedOpTypeEnum:
|
||||
hcclSum = 0
|
||||
hcclProd = 1
|
||||
hcclMax = 2
|
||||
hcclMin = 3
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.hcclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.hcclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.hcclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.hcclMin
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class HCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* HcclGetErrorString(HcclResult code);
|
||||
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
|
||||
|
||||
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
|
||||
Function("HcclGetRootInfo", hcclResult_t,
|
||||
[ctypes.POINTER(hcclUniqueId)]),
|
||||
|
||||
# HcclResult HcclCommInitRootInfo(
|
||||
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
|
||||
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
|
||||
Function("HcclCommInitRootInfo", hcclResult_t, [
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclUniqueId),
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(hcclComm_t),
|
||||
]),
|
||||
|
||||
# HcclResult HcclAllReduce(
|
||||
# void *sendBuf, void *recvBuf, uint64_t count,
|
||||
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
|
||||
# aclrtStream stream);
|
||||
Function("HcclAllReduce", hcclResult_t, [
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
hcclRedOp_t,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
# HcclResult HcclBroadcast(
|
||||
# void *buf, uint64_t count,
|
||||
# HcclDataType dataType, uint32_t root,
|
||||
# HcclComm comm, aclrtStream stream);
|
||||
Function("HcclBroadcast", hcclResult_t, [
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
hcclDataType_t,
|
||||
ctypes.c_int,
|
||||
hcclComm_t,
|
||||
aclrtStream_t,
|
||||
]),
|
||||
|
||||
# HcclResult HcclCommDestroy(HcclComm comm);
|
||||
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
    # to the corresponding directory
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_hccl_library()
|
||||
|
||||
try:
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
HCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = HCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load HCCL library from %s. "
|
||||
"It is expected if you are not running on Ascend NPUs."
|
||||
"Otherwise, the hccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable HCCL_SO_PATH"
|
||||
" to point to the correct hccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in HCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
for func in HCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def hcclGetErrorString(self, result: hcclResult_t) -> str:
|
||||
return self._funcs["HcclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def HCCL_CHECK(self, result: hcclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.hcclGetErrorString(result)
|
||||
raise RuntimeError(f"HCCL error: {error_str}")
|
||||
|
||||
def hcclGetUniqueId(self) -> hcclUniqueId:
|
||||
unique_id = hcclUniqueId()
|
||||
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
|
||||
rank: int) -> hcclComm_t:
|
||||
comm = hcclComm_t()
|
||||
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
|
||||
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
|
||||
return comm
|
||||
|
||||
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
# `datatype` actually should be `hcclDataType_t`
|
||||
# and `op` should be `hcclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
|
||||
root: int, comm: hcclComm_t,
|
||||
stream: aclrtStream_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
|
||||
root, comm, stream))
|
||||
|
||||
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
|
||||
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HCCLLibrary",
|
||||
"hcclDataTypeEnum",
|
||||
"hcclRedOpTypeEnum",
|
||||
"hcclUniqueId",
|
||||
"hcclComm_t",
|
||||
"aclrtStream_t",
|
||||
"buffer_type",
|
||||
]
|
||||
994
vllm_npu/distributed/llmdatadist_c_mgr_connector.py
Normal file
@@ -0,0 +1,994 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
|
||||
LLMException, LLMRole)
|
||||
from vllm import envs
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import get_ip, logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
import vllm_npu.envs as envs_ascend
|
||||
from vllm_npu.distributed.utils import get_transfer_timeout_value
|
||||
from vllm_npu.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.float16: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.bfloat16: llm_datadist.DataType.DT_BF16,
|
||||
torch.float: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.float32: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.int8: llm_datadist.DataType.DT_INT8,
|
||||
torch.int64: llm_datadist.DataType.DT_INT64,
|
||||
torch.int32: llm_datadist.DataType.DT_INT32
|
||||
}
|
||||
|
||||
|
||||
class LLMDataDistCMgrEvent(Enum):
|
||||
ReqForMetadata = 0
|
||||
ReqForFinished = 1
|
||||
|
||||
|
||||
class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
|
||||
super_pod_id: str
|
||||
server_id: str
|
||||
device_id: str
|
||||
device_ip: str
|
||||
super_device_id: str
|
||||
cluster_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: str
|
||||
engine_id: str
|
||||
remote_tp_size: str
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self):
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
|
||||
def add_new_req(self, request_id: str, local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any]):
|
||||
self.requests[request_id] = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_tp_size=kv_transfer_params["remote_tp_size"],
|
||||
)
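# Hedged example of the `kv_transfer_params` dict consumed above; the values
# are invented for illustration, only the key names come from this file:
#
#   {
#       "remote_block_ids": [12, 13, 14],
#       "remote_engine_id": "engine-0",
#       "remote_host": "10.0.0.1",
#       "remote_port": 5557,
#       "remote_tp_size": "4",
#   }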
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler: Optional[
|
||||
LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
|
||||
vllm_config, self.engine_id)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(
|
||||
self,
|
||||
kv_caches: dict[
|
||||
str, # type: ignore[override]
|
||||
Tuple[torch.Tensor]]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(finished_req_ids)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata,
|
||||
LLMDataDistCMgrConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager."""
|
||||
pass
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata, **kwargs) -> None:
|
||||
"""LLMDataDistCMgrConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
"""LLMDataDistCMgrConnector does not save explicitly."""
|
||||
pass
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorScheduler():
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id = engine_id
|
||||
self.local_ip = get_ip()
|
||||
# Can not retrieve the parallel config since it is not initialized.
|
||||
self.local_dp_rank = None
|
||||
self.tp_size = None
|
||||
if vllm_config.parallel_config.data_parallel_external_lb:
|
||||
dp_rank_local = vllm_config.parallel_config.data_parallel_rank
|
||||
else:
|
||||
dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
self.port = dp_rank_local * tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT
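# Worked example (base port assumed for illustration): with a base RPC port
# of 5557, dp_rank_local=2 and tp_size=4, this advertises 2 * 4 + 5557 = 5565;
# with dp_rank_local=None it falls back to 4 + 5557 = 5561.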
|
||||
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[str, float] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
"""
|
||||
For remote prefill, pull all prompt blocks from remote
|
||||
asynchronously relative to engine execution.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
Returns:
|
||||
* the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
* true if the external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps).
|
||||
"""
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}"
|
||||
)
|
||||
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
assert num_computed_tokens % self.block_size == 0
|
||||
# Note: We use the full token count as transmit data here.
|
||||
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
|
||||
return count, count > 0
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
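# Worked example (numbers assumed): for a 1024-token prompt with
# num_computed_tokens=256 and do_remote_prefill set, this returns
# (1024 - 256, True) = (768, True); without do_remote_prefill it returns
# (0, False) and the scheduler proceeds with a normal local prefill.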
|
||||
|
||||
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
|
||||
num_externel_tokens: int):
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}"
|
||||
)
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
if params.get("remote_block_ids"):
|
||||
if all(p in params for p in ("remote_engine_id", "remote_host",
|
||||
"remote_port", "remote_tp_size")):
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request, blocks.get_unhashed_block_ids())
|
||||
else:
|
||||
logger.warning("" \
|
||||
f"Invalid KVTransferParams {params}, This request will be discard")
|
||||
else:
|
||||
assert num_externel_tokens == 0
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
meta = LLMDataDistCMgrConnectorMetadata()
|
||||
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
meta.add_new_req(request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params)
|
||||
|
||||
meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)
|
||||
|
||||
# Clear the list once workers start the transfers
|
||||
self._reqs_need_recv.clear()
|
||||
self._reqs_need_send.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"LLMDataDistCMgrConnector request_finished, request_status=%s, "
|
||||
"kv_transfer_params=%s", request.status, params)
|
||||
|
||||
if (params is None or not params.get("do_remote_decode")
|
||||
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
||||
return False, None
|
||||
|
||||
# note: NIXL transfers full blocks only, but there is no strong reason to do
# that here, so we just transfer whatever data was computed on the prefill node.
# note: there might be issues with this; check it if results look unexpected.
|
||||
computed_block_ids = block_ids
|
||||
delay_free_blocks = len(computed_block_ids) > 0
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(computed_block_ids), request.request_id)
|
||||
# Prefill request on remote. It will be read from D upon completion
|
||||
self._reqs_need_send[request.request_id] = time.perf_counter(
|
||||
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_block_ids=computed_block_ids,
|
||||
remote_engine_id=self.engine_id,
|
||||
remote_host=self.local_ip,
|
||||
remote_port=self.port,
|
||||
remote_tp_size=str(
|
||||
self.vllm_config.parallel_config.tensor_parallel_size),
|
||||
)
|
||||
|
||||
|
||||
class LLMDataDistCMgrConnectorWorker():
|
||||
"""
|
||||
Implementation of Worker side methods
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
|
||||
# we assume the local node only contains dp and tp, and tp will not communicate inter-node.
|
||||
# for any scenario beyond this scope, the functionality of this connector is not guaranteed.
|
||||
self.local_rank_on_node = get_world_group().rank % (
|
||||
vllm_config.parallel_config.data_parallel_size_local *
|
||||
vllm_config.parallel_config.tensor_parallel_size)
|
||||
self.local_rank = get_world_group().local_rank
|
||||
if vllm_config.parallel_config.data_parallel_external_lb:
|
||||
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
else:
|
||||
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
|
||||
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.rank = get_world_group().rank
|
||||
self.local_ip = get_ip()
|
||||
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
|
||||
self.local_agent_metadata: Optional[
|
||||
LLMDataDistCMgrAgentMetadata] = None
|
||||
self.vllm_config = vllm_config
|
||||
self.executor = ThreadPoolExecutor(1)
|
||||
self.thread_lock = threading.Lock()
|
||||
|
||||
self.llm_datadist_role = None
|
||||
self.llm_datadist_remote_role = None
|
||||
if self.kv_transfer_config.kv_role == "kv_producer":
|
||||
self.llm_datadist_role = LLMRole.PROMPT
|
||||
self.llm_datadist_remote_role = LLMRole.DECODER
|
||||
elif self.kv_transfer_config.kv_role == "kv_consumer":
|
||||
self.llm_datadist_role = LLMRole.DECODER
|
||||
self.llm_datadist_remote_role = LLMRole.PROMPT
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}"
|
||||
)
|
||||
|
||||
# linked_cluster records clusters that have already established a connection; its format is {"cluster_id": "comm_name"}
|
||||
self.linked_cluster: dict[Any, Any] = {}
|
||||
self.prefill_device_list: list[tuple[int, int]] = []
|
||||
self.decode_device_list: list[tuple[int, int]] = []
|
||||
global_rank_table = self.read_offline_rank_table()
|
||||
self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
|
||||
self.llm_datadist = LLMDataDist(self.llm_datadist_role,
|
||||
self.local_agent_metadata.cluster_id)
|
||||
self.init_llm_datadist()
|
||||
self.finished_reqs: set[str] = set()
|
||||
self.soc_info = get_ascend_soc_version()
|
||||
# Set hccl deterministic for model execute
|
||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||
self.done_receiving_counts: defaultdict[str,
|
||||
set[int]] = defaultdict(set)
|
||||
self.reqs_to_send: dict[str, float] = {}
|
||||
|
||||
def listen_for_agent_metadata_req(self, event: threading.Event):
|
||||
assert self.local_agent_metadata is not None
|
||||
port = envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
|
||||
url = f"tcp://{envs_ascend.vllm_npu_LLMDD_RPC_IP}:{port}"
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_decoder = msgspec.msgpack.Decoder()
|
||||
msg_to_send = msg_encoder.encode(self.local_agent_metadata)
|
||||
logger.debug(f"Start to listen to address: {url}")
|
||||
logger.debug(
|
||||
f"The local agent metadata have {len(msg_to_send)} bytes here")
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers"
|
||||
)
|
||||
with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined]
|
||||
event.set()
|
||||
while True:
|
||||
identity, _, msg = sock.recv_multipart()
|
||||
event_msg, decode_msg = msg_decoder.decode(msg)
|
||||
event_msg = LLMDataDistCMgrEvent(event_msg)
|
||||
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
|
||||
if "cluster_id" in decode_msg:
|
||||
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
|
||||
)
|
||||
sock.send_multipart((identity, b"", msg_to_send))
|
||||
self.add_remote_agent(decode_msg)
|
||||
else:
|
||||
logger.warning(
|
||||
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
|
||||
)
|
||||
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
||||
finished_req_id = decode_msg[0]
|
||||
with self.thread_lock:
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||
)
|
||||
if finished_req_id in self.reqs_to_send:
|
||||
self.finished_reqs.add(finished_req_id)
|
||||
del self.reqs_to_send[finished_req_id]
|
||||
sock.send_multipart(
|
||||
(identity, b"", b"receiving decode finished"))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
||||
)
|
||||
|
||||
def init_llm_datadist(self):
|
||||
assert self.local_agent_metadata is not None
|
||||
llm_config = LLMConfig()
|
||||
llm_config.device_id = self.local_rank
|
||||
llm_config.sync_kv_timeout = get_transfer_timeout_value()
|
||||
llm_config.enable_switch_role = True
|
||||
llm_config.enable_cache_manager = True
|
||||
llm_config.enable_remote_cache_accessible = True
|
||||
llm_config_options = llm_config.generate_options()
|
||||
self.llm_datadist.init(llm_config_options)
|
||||
self.cache_manager = self.llm_datadist.cache_manager
|
||||
logger.info(
|
||||
f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
|
||||
)
|
||||
|
||||
def read_offline_rank_table(self):
|
||||
assert (
|
||||
envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
|
||||
rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
|
||||
with open(rank_table_path, "r", encoding="utf-8") as f:
|
||||
global_rank_table = json.load(f)
|
||||
decode_device_list = global_rank_table["decode_device_list"]
|
||||
for decode_device in decode_device_list:
|
||||
server_id = decode_device["server_id"]
|
||||
device_id = decode_device["device_id"]
|
||||
self.decode_device_list.append((server_id, device_id))
|
||||
prefill_device_list = global_rank_table["prefill_device_list"]
|
||||
for prefill_device in prefill_device_list:
|
||||
server_id = prefill_device["server_id"]
|
||||
device_id = prefill_device["device_id"]
|
||||
self.prefill_device_list.append((server_id, device_id))
|
||||
|
||||
# global_rank_table = json.dumps(global_rank_table)
|
||||
return global_rank_table
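# Hedged sketch of the expected rank table layout (field values are invented;
# the key names are the ones read here and in read_agent_metadata):
#
#   {
#       "prefill_device_list": [
#           {"server_id": "10.0.0.1", "device_id": "0", "device_ip": "29.0.0.1",
#            "cluster_id": "1", "super_pod_id": "0", "super_device_id": "0"}
#       ],
#       "decode_device_list": [
#           {"server_id": "10.0.0.2", "device_id": "0", "device_ip": "29.0.0.2",
#            "cluster_id": "2"}
#       ]
#   }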
|
||||
|
||||
@staticmethod
|
||||
def _get_visible_devices() -> Callable[[str], bool]:
|
||||
"""
|
||||
Return a predicate that checks whether the given device ID is visible,
i.e. ASCEND_RT_VISIBLE_DEVICES is unset or contains the device_id.
|
||||
"""
|
||||
visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
|
||||
if not visible_devices:
|
||||
return lambda device_id: True
|
||||
visible_device_list = visible_devices.split(",")
|
||||
return lambda device_id: device_id in visible_device_list
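# Example: with ASCEND_RT_VISIBLE_DEVICES="0,1" the returned predicate yields
# True for "0" and False for "4"; with the variable unset every id is visible.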
|
||||
|
||||
def read_agent_metadata(self, global_rank_table):
|
||||
device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
|
||||
devices_type_list = []
|
||||
agent_metadata = None
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
devices_type_list.append("prefill_device_list")
|
||||
elif self.llm_datadist_role == LLMRole.DECODER:
|
||||
devices_type_list.append("decode_device_list")
|
||||
else:
|
||||
devices_type_list.append("prefill_device_list")
|
||||
devices_type_list.append("decode_device_list")
|
||||
for device_type in devices_type_list:
|
||||
device_list = global_rank_table[device_type]
|
||||
device_list = [
|
||||
d for d in device_list if d.get("server_id") == self.local_ip
|
||||
and device_filter(d.get("device_id", ""))
|
||||
]
|
||||
if len(device_list) <= self.tp_rank:
|
||||
continue
|
||||
device_info = device_list[self.tp_rank]
|
||||
super_pod_id_ = device_info.get("super_pod_id", None)
|
||||
server_id_ = device_info["server_id"]
|
||||
device_id_ = device_info["device_id"]
|
||||
device_ip_ = device_info["device_ip"]
|
||||
super_device_id_ = device_info.get("super_device_id", None)
|
||||
cluster_id_ = int(device_info["cluster_id"])
|
||||
agent_metadata = LLMDataDistCMgrAgentMetadata(
|
||||
super_pod_id=super_pod_id_,
|
||||
server_id=server_id_,
|
||||
device_id=device_id_,
|
||||
device_ip=device_ip_,
|
||||
super_device_id=super_device_id_,
|
||||
cluster_id=cluster_id_,
|
||||
)
|
||||
assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
|
||||
return agent_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
assert len(first_kv_cache_tuple) > 1
|
||||
assert self.local_agent_metadata is not None
|
||||
kv_cache_dtype = first_kv_cache.dtype
|
||||
self.use_mla: bool = first_kv_cache_tuple[0].size(
|
||||
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||
first_kv_cache_tuple) == 2
|
||||
self.use_sparse: bool = len(first_kv_cache_tuple) == 3
|
||||
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
|
||||
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
|
||||
# MHA case. [2 (k and v), num_blocks, ...]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
|
||||
self.block_len = math.prod(block_shape)
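# Worked example (shapes assumed for illustration): for a non-MLA cache of
# shape [num_blocks, block_size=128, num_kv_heads=8, head_size=128], the last
# three dims give block_shape=(128, 8, 128) and block_len = 128*8*128 = 131072
# elements per block.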
|
||||
self.cache_addr: list[int] = []
|
||||
alignment = 2 * 1024 * 1024
|
||||
if self.use_mla:
|
||||
cache_k_normed_addr_list = []
|
||||
cache_k_pe_addr_list = []
|
||||
k_normed = None
|
||||
k_pe = None
|
||||
for cache_or_caches in kv_caches.values():
|
||||
assert len(cache_or_caches) > 1
|
||||
k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
|
||||
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)
|
||||
|
||||
cache_desc_k_normed = CacheDesc(
|
||||
len(self.cache_addr[0]), [*k_normed.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_pe = CacheDesc(
|
||||
len(self.cache_addr[1]), [*k_pe.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=0)
|
||||
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=1)
|
||||
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
|
||||
self.cache_key = (cache_key_k_normed, cache_key_k_pe)
|
||||
try:
|
||||
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||
self.cache = (cache_k_normed, cache_k_pe)
|
||||
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
elif self.use_sparse:
|
||||
cache_k_normed_addr_list = []
|
||||
cache_k_pe_addr_list = []
|
||||
cache_k_idx_addr_list = []
|
||||
k_normed = None
|
||||
k_pe = None
|
||||
k_idx = None
|
||||
for cache_or_caches in kv_caches.values():
|
||||
assert len(cache_or_caches) > 1
|
||||
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
|
||||
1], cache_or_caches[2]
|
||||
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||
cache_k_idx_addr_list.append(k_idx.data_ptr())
|
||||
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
|
||||
cache_k_idx_addr_list)
|
||||
|
||||
cache_desc_k_normed = CacheDesc(
|
||||
len(self.cache_addr[0]), [*k_normed.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_pe = CacheDesc(
|
||||
len(self.cache_addr[1]), [*k_pe.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_idx = CacheDesc(
|
||||
len(self.cache_addr[2]), [*k_idx.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=0)
|
||||
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=1)
|
||||
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=2)
|
||||
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
|
||||
cache_desc_k_idx)
|
||||
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
|
||||
cache_key_k_idx)
|
||||
try:
|
||||
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||
cache_k_idx = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
|
||||
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
|
||||
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
else:
|
||||
for cache_or_caches in kv_caches.values():
|
||||
for cache in cache_or_caches:
|
||||
base_addr = cache.data_ptr()
|
||||
assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
||||
self.cache_addr.append(base_addr)
|
||||
# register paged kv cache into the llm_cache manager
|
||||
self.cache_desc = CacheDesc(
|
||||
len(self.cache_addr), [*cache.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
self.cache_key = BlocksCacheKey(
|
||||
cluster_id=int(self.local_agent_metadata.cluster_id))
|
||||
logger.info(
|
||||
f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
|
||||
)
|
||||
try:
|
||||
self.cache = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc, self.cache_addr, self.cache_key)
|
||||
logger.info(
|
||||
"LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
self.ready_event = threading.Event()
|
||||
self.metadata_agent_listener_t = threading.Thread(
|
||||
target=self.listen_for_agent_metadata_req,
|
||||
args=(self.ready_event, ),
|
||||
daemon=True,
|
||||
name="metadata_agent_listener")
|
||||
self.metadata_agent_listener_t.start()
|
||||
self.ready_event.wait()
|
||||
|
||||
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
|
||||
futures = []
|
||||
for req_id, meta in metadata.requests.items():
|
||||
logger.debug(f"Start to transmit {req_id}")
|
||||
future = self.executor.submit(
|
||||
self._read_blocks,
|
||||
local_block_ids=meta.local_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
remote_ip=meta.remote_host,
|
||||
remote_port=int(meta.remote_port),
|
||||
remote_engine_id=meta.engine_id,
|
||||
request_id=req_id,
|
||||
remote_tp_size=meta.remote_tp_size,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
def handle_exception(future):
|
||||
if future.exception():
|
||||
logger.error(f"KV transfer task failed: {future.exception()}")
|
||||
|
||||
for future in futures:
|
||||
future.add_done_callback(handle_exception)
|
||||
self.reqs_to_send.update(metadata.reqs_to_send)
|
||||
|
||||
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
|
||||
assert self.local_agent_metadata is not None
|
||||
remote_cluster_id = metadata.cluster_id
|
||||
if remote_cluster_id in self.linked_cluster:
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
|
||||
)
|
||||
return remote_cluster_id
|
||||
remote_super_pod_id = metadata.super_pod_id
|
||||
remote_server_id = metadata.server_id
|
||||
is_same_server = remote_server_id == self.local_agent_metadata.server_id
|
||||
is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
prefill_metadata = self.local_agent_metadata
|
||||
decode_metadata = metadata
|
||||
else:
|
||||
prefill_metadata = metadata
|
||||
decode_metadata = self.local_agent_metadata
|
||||
comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
|
||||
cluster_rank_info = {
|
||||
prefill_metadata.cluster_id: 0,
|
||||
decode_metadata.cluster_id: 1
|
||||
}
|
||||
rank_table = {}
|
||||
rank_table["version"] = "1.2"
|
||||
rank_table["server_count"] = "1" if is_same_server else "2"
|
||||
rank_table["status"] = "completed"
|
||||
|
||||
# generate server_list for rank table
|
||||
rank_table["server_list"] = [] # type: ignore[assignment]
|
||||
decode_server_device_info = None
|
||||
prefill_server_device_info = {
|
||||
"device": [{
|
||||
k: v
|
||||
for k, v in [(
|
||||
"device_id", prefill_metadata.device_id
|
||||
), ("device_ip", prefill_metadata.device_ip
|
||||
), ("super_device_id",
|
||||
prefill_metadata.super_device_id), ("rank_id", "0")]
|
||||
if v is not None
|
||||
}],
|
||||
"server_id":
|
||||
prefill_metadata.server_id
|
||||
}
|
||||
if is_same_server:
|
||||
prefill_server_device_info["device"].append( # type: ignore[attr-defined]
|
||||
{
|
||||
k: v
|
||||
for k, v in [(
|
||||
"device_id", decode_metadata.device_id
|
||||
), ("device_ip", decode_metadata.device_ip
|
||||
), ("super_device_id",
|
||||
decode_metadata.super_device_id), ("rank_id", "1")]
|
||||
if v is not None
|
||||
})
|
||||
else:
|
||||
decode_server_device_info = {
|
||||
"device": [{
|
||||
k: v
|
||||
for k, v in [(
|
||||
"device_id", decode_metadata.device_id
|
||||
), ("device_ip", decode_metadata.device_ip
|
||||
), ("super_device_id",
|
||||
decode_metadata.super_device_id), ("rank_id", "1")]
|
||||
if v is not None
|
||||
}],
|
||||
"server_id":
|
||||
decode_metadata.server_id
|
||||
}
|
||||
rank_table["server_list"].append( # type: ignore[attr-defined]
|
||||
prefill_server_device_info)
|
||||
if decode_server_device_info is not None:
|
||||
rank_table["server_list"].append( # type: ignore[attr-defined]
|
||||
decode_server_device_info)
|
||||
|
||||
if self.soc_info == AscendSocVersion.A3:
|
||||
# generate super_pod_list for rank table
|
||||
super_pod_list = []
|
||||
prefill_super_pod_info = {
|
||||
"super_pod_id": prefill_metadata.super_pod_id,
|
||||
"server_list": [{
|
||||
"server_id": prefill_metadata.server_id
|
||||
}],
|
||||
}
|
||||
if is_same_pod and not is_same_server:
|
||||
prefill_super_pod_info[
|
||||
"server_list"].append( # type: ignore[attr-defined]
|
||||
{"server_id": decode_metadata.server_id})
|
||||
super_pod_list.append(prefill_super_pod_info)
|
||||
if not is_same_pod:
|
||||
decode_super_pod_id = {
|
||||
"super_pod_id": decode_metadata.super_pod_id,
|
||||
"server_list": [{
|
||||
"server_id": decode_metadata.server_id
|
||||
}],
|
||||
}
|
||||
super_pod_list.append(decode_super_pod_id)
|
||||
rank_table[
|
||||
"super_pod_list"] = super_pod_list # type: ignore[assignment]
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
|
||||
)
|
||||
logger.info(f"rank table \n{rank_table}")
|
||||
logger.info(f"comm name: {comm_name}")
|
||||
logger.info(f"cluster rank info: {cluster_rank_info}")
|
||||
comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
|
||||
json.dumps(rank_table))
|
||||
while True:
|
||||
ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
|
||||
if ret == llm_datadist.RegisterMemStatus.OK:
|
||||
logger.info(
|
||||
f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
|
||||
)
|
||||
break
|
||||
elif ret == llm_datadist.RegisterMemStatus.FAILED:
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
|
||||
)
|
||||
time.sleep(1)
|
||||
logger.info("Checking query_register_mem_status again")
|
||||
self.linked_cluster.update({remote_cluster_id: comm_id})
|
||||
logger.info(f"cached linked cluster: {self.linked_cluster}")
|
||||
logger.info(
|
||||
f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
|
||||
)
|
||||
return remote_cluster_id
|
||||
|
||||
def remove_remote_agent(self, cluster_id: int):
|
||||
if cluster_id not in self.linked_cluster:
|
||||
logger.warning(
f"LLMDataDistCMgrConnectorWorker: Can't remove remote client with cluster id {cluster_id} because it does not exist in the linked_cluster list"
)
return
|
||||
comm_id = self.linked_cluster[cluster_id]
|
||||
try:
|
||||
self.llm_datadist.unlink(comm_id)
|
||||
self.linked_cluster.pop(cluster_id)
|
||||
except LLMException:
|
||||
logger.error(
|
||||
f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment"
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully remove remote client with cluster id {cluster_id} !"
|
||||
)
|
||||
|
||||
def connect_to_remote_agent(self, host: str, port: int) -> int:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Querying metadata from url: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
logger.info("Try request remote metadata from socket......")
|
||||
sock.send(msg_send)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder()
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
|
||||
logger.info(f"recving metadata: {metadata}")
|
||||
cluster_id = self.add_remote_agent(metadata)
|
||||
return cluster_id
|
||||
|
||||
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
|
||||
for port in ports:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}"
|
||||
)
|
||||
|
||||
def _read_blocks(
|
||||
self,
|
||||
local_block_ids: list[int],
|
||||
remote_block_ids: list[int],
|
||||
remote_ip: str,
|
||||
remote_port: int,
|
||||
remote_engine_id: str,
|
||||
request_id: str,
|
||||
remote_tp_size: str,
|
||||
):
|
||||
# if remote_ip not in self.linked_cluster:
|
||||
tp_offset = self.tp_rank % int(remote_tp_size)
|
||||
remote_cluster_id = self.connect_to_remote_agent(
|
||||
remote_ip, remote_port + tp_offset)
|
||||
num_local_blocks = len(local_block_ids)
|
||||
if num_local_blocks == 0:
|
||||
return
|
||||
num_remote_blocks = len(remote_block_ids)
|
||||
assert num_local_blocks <= num_remote_blocks
|
||||
if num_local_blocks < num_remote_blocks:
|
||||
remote_block_ids = remote_block_ids[-num_local_blocks:]
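# Worked example (ids assumed): with local_block_ids=[4, 5] and
# remote_block_ids=[10, 11, 12], only the trailing two remote blocks are
# pulled, i.e. remote_block_ids becomes [11, 12].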
|
||||
|
||||
logger.info(f"remote cluster id is: {remote_cluster_id}")
|
||||
if self.use_mla:
|
||||
remote_cache_key_k_normed = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=0)
|
||||
remote_cache_key_k_pe = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=1)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_normed,
|
||||
self.cache[0], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_pe,
|
||||
self.cache[1], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
elif self.use_sparse:
|
||||
remote_cache_key_k_normed = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=0)
|
||||
remote_cache_key_k_pe = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=1)
|
||||
remote_cache_key_k_idx = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=2)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_normed,
|
||||
self.cache[0], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_pe,
|
||||
self.cache[1], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_idx,
|
||||
self.cache[2], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
else:
|
||||
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key,
|
||||
self.cache, # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
remote_ports = list(
|
||||
range(remote_port + self.tp_rank,
|
||||
remote_port + int(remote_tp_size), self.tp_size))
|
||||
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
|
||||
with self.thread_lock:
|
||||
self.finished_reqs.add(request_id)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""Get the finished recving and sending requuests."""
|
||||
now = time.perf_counter()
|
||||
with self.thread_lock:
|
||||
while self.reqs_to_send:
|
||||
req_id, expires = next(iter(self.reqs_to_send.items()))
|
||||
if now < expires:
|
||||
break
|
||||
logger.warning(
|
||||
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
|
||||
"If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
|
||||
)
|
||||
if req_id in self.reqs_to_send:
|
||||
self.finished_reqs.add(req_id)
|
||||
del self.reqs_to_send[req_id]
|
||||
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
|
||||
self.finished_reqs.clear()
|
||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||
return req_ids_to_ret, None
|
||||
else:
|
||||
return None, req_ids_to_ret
|
||||
|
||||
|
||||
# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
|
||||
@contextlib.contextmanager
|
||||
def zmq_ctx(socket_type: Any,
|
||||
addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx: Optional[zmq.Context] = None # type: ignore[name-defined]
|
||||
try:
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
|
||||
if socket_type == zmq.ROUTER: # type: ignore[attr-defined]
|
||||
socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined]
|
||||
socket.bind(addr)
|
||||
elif socket_type == zmq.REQ: # type: ignore[attr-defined]
|
||||
socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined]
|
||||
socket.connect(addr)
|
||||
else:
|
||||
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||
|
||||
yield socket
|
||||
finally:
|
||||
if ctx is not None:
|
||||
ctx.destroy(linger=0)
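# Hedged usage sketch: the listener side binds a ROUTER socket, the requester
# side connects a REQ socket; the address below is an assumption and expects a
# listener to already be bound there.
#
#   with zmq_ctx(zmq.REQ, "tcp://127.0.0.1:5557") as sock:
#       sock.send(b"ping")
#       reply = sock.recv()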
|
||||
0
vllm_npu/distributed/mooncake/__init__.py
Normal file
449
vllm_npu/distributed/mooncake/config_data.py
Normal file
@@ -0,0 +1,449 @@
|
||||
import array
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.utils import cdiv, logger
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeEngineMetadata:
|
||||
"""name of the LLM model"""
|
||||
|
||||
model_name: str
|
||||
""" world size when running under a distributed setting """
|
||||
world_size: int
|
||||
""" worker id when running under a distributed setting """
|
||||
worker_id: int
|
||||
""" the format of kv tensors """
|
||||
kv_dtype: torch.dtype
|
||||
""" the shape of kv tensors """
|
||||
""" (num_layer, 2, metadata.block_size, num_kv_head, head_size) """
|
||||
kv_shape: tuple[int, int, int, int, int]
|
||||
block_size: int = 128
|
||||
""" whether use MLA"""
|
||||
use_mla: bool = False
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class MooncakeEngineKey:
|
||||
model_name: str
|
||||
world_size: int
|
||||
worker_id: int
|
||||
chunk_hash: str
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (f"{self.model_name}@{self.world_size}"
|
||||
f"@{self.worker_id}@{self.chunk_hash}")
|
||||
|
||||
def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
|
||||
"""Split the key into multiple keys for each layer"""
|
||||
keys = []
|
||||
for layer_id in range(num_layers):
|
||||
keys.append(
|
||||
LayerMooncakeEngineKey(
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
layer_id,
|
||||
))
|
||||
return keys
|
||||
|
||||
def to_dict(self):
|
||||
# Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
|
||||
return {
|
||||
"__type__": "CacheEngineKey",
|
||||
"model_name": self.model_name,
|
||||
"world_size": self.world_size,
|
||||
"worker_id": self.worker_id,
|
||||
"chunk_hash": self.chunk_hash,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return MooncakeEngineKey(
|
||||
model_name=d["model_name"],
|
||||
world_size=d["world_size"],
|
||||
worker_id=d["worker_id"],
|
||||
chunk_hash=d["chunk_hash"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class LayerMooncakeEngineKey(MooncakeEngineKey):
|
||||
"""A key for the layer cache engine"""
|
||||
|
||||
layer_id: int
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.model_name,
|
||||
self.world_size,
|
||||
self.worker_id,
|
||||
self.chunk_hash,
|
||||
self.layer_id,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (f"{self.model_name}@{self.world_size}"
|
||||
f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
|
||||
|
||||
|
||||
class ChunkedTokenDatabase():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata: MooncakeEngineMetadata,
|
||||
):
|
||||
self.metadata = metadata
|
||||
|
||||
def _make_key_by_hash(self,
|
||||
chunk_hash: str,
|
||||
layer_id: Optional[int] = None):
|
||||
assert self.metadata is not None
|
||||
return MooncakeEngineKey(
|
||||
self.metadata.model_name,
|
||||
self.metadata.world_size,
|
||||
self.metadata.worker_id,
|
||||
chunk_hash,
|
||||
)
|
||||
|
||||
def _hash(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
prefix_hash: str,
|
||||
) -> str:
|
||||
# TODO: change it to a more efficient hash function
|
||||
if isinstance(tokens, torch.Tensor):
|
||||
tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
|
||||
elif isinstance(tokens, list):
|
||||
tokens_bytes = array.array("I", tokens).tobytes()
|
||||
return hashlib.sha256(prefix_hash.encode("ascii") +
|
||||
tokens_bytes).hexdigest()
|
||||
|
||||
def _chunk_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
) -> Iterable[Union[torch.Tensor, List[int]]]:
|
||||
"""
|
||||
Chunk the tokens into chunks of size self.metadata.block_size.
|
||||
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
|
||||
|
||||
:return: a generator of chunks of tokens, each with
|
||||
shape [metadata.block_size]
|
||||
"""
|
||||
for i in range(0, len(tokens), self.metadata.block_size):
|
||||
yield tokens[i:i + self.metadata.block_size]
|
||||
|
||||
def _prefix_hash(
|
||||
self,
|
||||
token_chunks: Iterable[Union[torch.Tensor, List[int]]],
|
||||
) -> Iterable[str]:
|
||||
prefix_hash = ''
|
||||
for token_chunk in token_chunks:
|
||||
prefix_hash = self._hash(token_chunk, prefix_hash)
|
||||
yield prefix_hash
|
||||
|
||||
def process_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
|
||||
"""Process the tokens and return the corresponding cache engine keys.
|
||||
|
||||
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens need to be matched,
|
||||
and the Falses will ALWAYS be at the PREFIX of the tensor.
|
||||
|
||||
|
||||
|
||||
:returns: An iterable of tuples with three elements. The first element
|
||||
is the start index of the tokens for the key. The second element
|
||||
is the end index of the tokens for the key. The third element is
|
||||
the cache engine key (or hash) for the tokens.
|
||||
|
||||
:raises: ValueError if the number of Falses in the mask is not a
|
||||
multiple of the chunk size.
|
||||
"""
|
||||
if mask is not None:
|
||||
num_falses = mask.numel() - mask.long().sum().item()
|
||||
else:
|
||||
num_falses = 0
|
||||
|
||||
if num_falses % self.metadata.block_size != 0:
|
||||
raise ValueError(
|
||||
"The number of Falses in the mask is not a multiple of the chunk size."
|
||||
)
|
||||
total_len = len(tokens)
|
||||
|
||||
token_chunks = self._chunk_tokens(tokens)
|
||||
prefix_hashes = self._prefix_hash(token_chunks)
|
||||
|
||||
start_idx = 0
|
||||
for chunk_id, hash_val in enumerate(prefix_hashes):
|
||||
start_idx = chunk_id * self.metadata.block_size
|
||||
end_idx = min(start_idx + self.metadata.block_size, total_len)
|
||||
if start_idx < num_falses:
|
||||
continue
|
||||
else:
|
||||
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
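# Hedged usage sketch (token values and block_size are assumed): with
# metadata.block_size=4 and tokens=[1, 2, 3, 4, 5, 6], process_tokens yields
#   (0, 4, key_for_hash_of([1, 2, 3, 4]))
#   (4, 6, key_for_hash_of([1, 2, 3, 4, 5, 6]))
# where each hash is the rolling prefix hash computed by _prefix_hash.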
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadSpec:
|
||||
# Number of tokens cached in vLLM
|
||||
vllm_cached_tokens: int
|
||||
# Number of tokens that are cached in mooncake
|
||||
mooncake_cached_tokens: int
|
||||
# Whether the scheduler allow us to load the tokens
|
||||
can_load: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaveSpec:
|
||||
# Skip already saved tokens
|
||||
skip_leading_tokens: int
|
||||
# Whether the scheduler allow us to save the tokens
|
||||
can_save: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestTracker:
|
||||
# Request id
|
||||
req_id: str
|
||||
|
||||
# The token ids that have been scheduled so far
|
||||
token_ids: list[int]
|
||||
|
||||
# The block ids that have been allocated so far
|
||||
# NOTE: allocated blocks could be more than the number of tokens
|
||||
# FIXME: need to check whether the block ids will be changed after
|
||||
# preemption
|
||||
allocated_block_ids: list[int]
|
||||
|
||||
# The number of tokens that have been saved
|
||||
num_saved_tokens: int = 0
|
||||
|
||||
@staticmethod
|
||||
def from_new_request(
|
||||
new_request: "NewRequestData",
|
||||
num_tokens_to_compute: int,
|
||||
) -> "RequestTracker":
|
||||
"""Create the request tracker from a new request.
|
||||
|
||||
Args:
|
||||
new_request (NewRequestData): the new request data.
|
||||
num_tokens_to_compute (int): the number of tokens that will
|
||||
be 'computed', including the `num_computed_tokens` (vLLM's
|
||||
local cache hit) and new tokens that will be scheduled.
|
||||
|
||||
"""
|
||||
# vLLM 0.9.0 update: request.block_ids changed from list[int] to
|
||||
# list[list[int]]
|
||||
# Need to check the type of request.block_ids
|
||||
|
||||
unfolded_block_ids = []
|
||||
|
||||
if not isinstance(new_request.block_ids[0], list):
|
||||
unfolded_block_ids = new_request.block_ids.copy()
|
||||
else:
|
||||
unfolded_block_ids = new_request.block_ids[0].copy()
|
||||
|
||||
return RequestTracker(
|
||||
req_id=new_request.req_id,
|
||||
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
|
||||
copy(),
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
new_block_ids: Union[tuple[list[int], ...], list[int]],
|
||||
) -> None:
|
||||
"""Update the request tracker when a running request is
|
||||
scheduled again
|
||||
"""
|
||||
|
||||
self.token_ids.extend(new_token_ids)
|
||||
|
||||
if len(new_block_ids) == 0:
|
||||
new_block_ids = []
|
||||
elif isinstance(new_block_ids, tuple):
|
||||
new_block_ids = new_block_ids[0]
|
||||
elif isinstance(new_block_ids, list):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported new_block_ids type {type(new_block_ids)}")
|
||||
self.allocated_block_ids.extend(new_block_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request id
|
||||
req_id: str
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
|
||||
block_ids: list[int]
|
||||
# # Slot mapping if exchange for block_id
|
||||
# slot_mapping: torch.Tensor
|
||||
# Skip save or not
|
||||
save_spec: Optional[SaveSpec] = None
|
||||
# load_spec
|
||||
load_spec: Optional[LoadSpec] = None
|
||||
|
||||
is_last_chunk: Optional[bool] = None
|
||||
|
||||
@staticmethod
|
||||
def from_request_tracker(
|
||||
tracker: RequestTracker,
|
||||
block_size: int,
|
||||
load_spec: Optional[LoadSpec] = None,
|
||||
skip_save: Optional[bool] = False,
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
discard_partial_chunks: bool = True,
|
||||
) -> Optional["ReqMeta"]:
|
||||
"""Create the request metadata from a request tracker.
|
||||
|
||||
Args:
|
||||
tracker (RequestTracker): the request tracker.
|
||||
block_size (int): the block size in vLLM.
|
||||
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
|
||||
skip_save (bool): whether to skip the save operation.
|
||||
discard_partial_chunks (bool): whether to discard partial chunks.
|
||||
|
||||
Returns:
|
||||
the request metadata if we need to perform load/save
|
||||
operations, None otherwise.
|
||||
"""
|
||||
input_token_ids = tracker.token_ids
|
||||
input_token_len = len(input_token_ids)
|
||||
|
||||
# For save operation: do not save if the following conditions are met:
# 1. the request has already been saved before (num_saved_tokens > 0)
# 2. the number of unsaved tokens has not reached the chunk boundary
|
||||
skip_leading_tokens = tracker.num_saved_tokens
|
||||
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
|
||||
block_size if discard_partial_chunks else 0)
|
||||
# Calculate number of tokens to save based on discard_partial_chunks
|
||||
# setting
|
||||
num_tokens_to_save = ((input_token_len // block_size * block_size)
|
||||
if discard_partial_chunks else input_token_len)
|
||||
|
||||
skip_save = skip_save or num_tokens_to_save < chunk_boundary
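# Worked example (illustrative numbers only, assuming cdiv is ceiling
# division as in vllm.utils): with block_size=128, tracker.num_saved_tokens=256
# and len(input_token_ids)=300 under discard_partial_chunks=True:
#   chunk_boundary     = cdiv(257, 128) * 128 = 3 * 128 = 384
#   num_tokens_to_save = 300 // 128 * 128     = 2 * 128 = 256
# Since 256 < 384, the save is skipped until at least one more full block
# of unsaved tokens has accumulated.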
|
||||
if skip_save and load_spec is None:
|
||||
return None
|
||||
|
||||
# If we need to save, update the number of saved tokens
|
||||
if not skip_save:
|
||||
tracker.num_saved_tokens = num_tokens_to_save
|
||||
save_spec = SaveSpec(skip_leading_tokens, not skip_save)
|
||||
|
||||
# Calculate the token ids and slot mappings for load and save
|
||||
# OPTIMIZATION: pre-allocate the buffer for token ids and block ids
|
||||
token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]
|
||||
|
||||
# For load operation: check whether the request is scheduled to load
|
||||
if load_spec is not None and load_spec.can_load:
|
||||
logger.debug(
|
||||
"Scheduled to load %d tokens for request %s",
|
||||
load_spec.mooncake_cached_tokens,
|
||||
tracker.req_id,
|
||||
)
|
||||
else:
|
||||
# Do not load if not in `can_load` state
|
||||
load_spec = None
|
||||
logger.debug(
|
||||
f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}"
|
||||
)
|
||||
return ReqMeta(
|
||||
req_id=tracker.req_id,
|
||||
token_ids=token_ids,
|
||||
block_ids=tracker.allocated_block_ids,
|
||||
save_spec=save_spec,
|
||||
load_spec=load_spec,
|
||||
is_last_chunk=is_last_chunk,
|
||||
)
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self, unfinished_request_ids):
|
||||
self.requests = []
|
||||
self.unfinished_request_ids = unfinished_request_ids
|
||||
|
||||
def add_request(self, req_meta: ReqMeta) -> None:
|
||||
"""Add a request to the metadata.
|
||||
|
||||
Args:
|
||||
req_meta (ReqMeta): the request metadata.
|
||||
"""
|
||||
self.requests.append(req_meta)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LasyerMultiBlockReqMeta:
|
||||
req_id: str
|
||||
keys: List[LayerMooncakeEngineKey]
|
||||
starts: List[int]
|
||||
ends: list[int]
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeStoreConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
global_segment_size: int
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
master_server_address: str
|
||||
use_ascend_direct: bool
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
||||
with open(file_path) as file:
|
||||
config = json.load(file)
|
||||
return MooncakeStoreConfig(
|
||||
local_hostname=config.get("local_hostname"),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=config.get("global_segment_size", 3355443200),
|
||||
local_buffer_size=config.get("local_buffer_size", 1073741824),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"),
|
||||
use_ascend_direct=config.get("use_ascend_direct", False))
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> "MooncakeStoreConfig":
|
||||
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_path)
|
||||
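# Minimal usage sketch for MooncakeStoreConfig. The path and field values
# below are hypothetical examples (only the field names and defaults come
# from the dataclass above):
#
#   # /tmp/mooncake.json
#   # {
#   #   "local_hostname": "127.0.0.1",
#   #   "metadata_server": "127.0.0.1:2379",
#   #   "protocol": "ascend",
#   #   "device_name": "",
#   #   "master_server_address": "127.0.0.1:50051",
#   #   "use_ascend_direct": false
#   # }
#
#   import os
#   os.environ["MOONCAKE_CONFIG_PATH"] = "/tmp/mooncake.json"
#   config = MooncakeStoreConfig.load_from_env()
#   assert config.global_segment_size == 3355443200  # unset, falls back to default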
vllm_npu/distributed/mooncake/kv_transfer.py (new file, 293 lines)
@@ -0,0 +1,293 @@
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.utils import logger
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
|
||||
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class KVTransferThread(threading.Thread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event, name: str):
|
||||
super().__init__(daemon=True, name=name)
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.m_store = m_store
|
||||
self.ready_event = ready_event
|
||||
self.kv_caches_base_addr = local_kv_caches_base_addr
|
||||
self.block_len = block_len
|
||||
self.token_database = token_database
|
||||
self.block_size = block_size
|
||||
self.done_task_lock = threading.Lock()
|
||||
# TODO(jianzs): find a better way to detect MLA.
|
||||
self.use_mla = len(block_len) == 2
|
||||
|
||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||
# TODO(jianzs): make this configurable
|
||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||
self.finished_requests: set[str] = set()
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
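# Worked example for prepare_value (hypothetical numbers): with
# block_size=128, a single block_len of 65536 bytes and block_ids=[7, 9],
# a chunk covering tokens [128, 256) resolves to block_ids[128 // 128] = 9,
# so every KV cache tensor contributes
#   addr   = base_addr + 9 * 65536
#   length = int(65536 / 128 * (256 - 128)) = 65536   # the full block
# while a partial chunk [128, 192) would cover only the first 32768 bytes.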
|
||||
|
||||
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
|
||||
layer_id: int):
|
||||
block_id = block_ids[start // self.block_size]
|
||||
if self.use_mla:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[1]
|
||||
length_k = int(self.block_len[0] / self.block_size * (end - start))
|
||||
length_v = int(self.block_len[1] / self.block_size * (end - start))
|
||||
size_list = [length_k, length_v]
|
||||
else:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[0]
|
||||
length = int(self.block_len[0] / self.block_size * (end - start))
|
||||
size_list = [length, length]
|
||||
addr_list = [addr_k, addr_v]
|
||||
return addr_list, size_list
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
current_event: Optional[torch.npu.Event] = None,
|
||||
) -> torch.Tensor:
|
||||
req = ({
|
||||
"req_id": req_id,
|
||||
"tokens": tokens,
|
||||
"block_ids": block_ids,
|
||||
"mask": mask,
|
||||
"is_last_chunk": is_last_chunk,
|
||||
"current_event": current_event,
|
||||
})
|
||||
self.request_queue.put(req)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
Get and clear the requests that have been completed.
|
||||
Returns:
|
||||
A set of request IDs that have been completed.
|
||||
"""
|
||||
with self.done_task_lock:
|
||||
finished_requests = self.finished_requests.copy()
|
||||
self.finished_requests.clear()
|
||||
return finished_requests
|
||||
|
||||
def set_finished_request(self, req_id):
|
||||
with self.done_task_lock:
|
||||
self.finished_requests.add(req_id)
|
||||
|
||||
def run(self):
|
||||
"""Run the thread to handle KV cache transfer requests."""
|
||||
self.ready_event.set()
|
||||
while True:
|
||||
try:
|
||||
request_data = self.request_queue.get()
|
||||
if request_data is None:
|
||||
logger.warning("Received a None request!")
|
||||
self.request_queue.task_done()
|
||||
continue
|
||||
self._handle_request(request_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in KVCacheTransferThread: {e}")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
pass
|
||||
|
||||
|
||||
class KVCacheStoreSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheSendingThread")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
tokens = req_meta["tokens"]
|
||||
mask = req_meta["mask"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
req_id = req_meta["req_id"]
|
||||
is_last_chunk = req_meta["is_last_chunk"]
|
||||
current_event = req_meta["current_event"]
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
if key_list:
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
This issue will be fixed in CANN version 8.5.rc1.
|
||||
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||
to resolve this issue before the 8.5.RC1 release.
|
||||
"""
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put(key, addr, size)
|
||||
if is_last_chunk:
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreRecvingThread")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
tokens = req_meta["tokens"]
|
||||
mask = req_meta["mask"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
req_id = req_meta["req_id"]
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
addr, size, _ = self.prepare_value(start, end, block_ids)
|
||||
self.m_store.get(key, addr, size)
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event,
|
||||
num_layers: int):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerSendingThread")
|
||||
self.final_layer_id = num_layers - 1
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.prepare_value_layer(req_meta.starts[index],
|
||||
req_meta.ends[index],
|
||||
req_meta.block_ids,
|
||||
req_meta.layer_id)
|
||||
self.m_store.put(key, addr, size)
|
||||
if req_meta.layer_id == self.final_layer_id:
|
||||
self.set_finished_request(req_meta.req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
|
||||
local_kv_caches_base_addr: list[int],
|
||||
token_database: ChunkedTokenDatabase, block_len: list[int],
|
||||
block_size: int, ready_event: threading.Event,
|
||||
get_event: threading.Event):
|
||||
super().__init__(tp_rank,
|
||||
tp_size,
|
||||
m_store,
|
||||
local_kv_caches_base_addr,
|
||||
token_database,
|
||||
block_len,
|
||||
block_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerRecvingThread")
|
||||
self.get_event = get_event
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.prepare_value_layer(req_meta.starts[index],
|
||||
req_meta.ends[index],
|
||||
req_meta.block_ids,
|
||||
req_meta.layer_id)
|
||||
self.m_store.get(key, addr, size)
|
||||
self.request_queue.task_done()
|
||||
self.get_event.set()
|
||||
vllm_npu/distributed/mooncake/mooncake_engine.py (new file, 639 lines)
@@ -0,0 +1,639 @@
|
||||
# Standard
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import get_kv_cache_torch_dtype, logger
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
|
||||
MooncakeEngineMetadata)
|
||||
from vllm_npu.distributed.mooncake.kv_transfer import (
|
||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class MooncakeEngine:
|
||||
# The main class for the cache engine.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
use_layerwise: bool,
|
||||
):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.use_mla = False
|
||||
if (hasattr(model_config, "use_mla")
|
||||
and isinstance(model_config.use_mla, bool)
|
||||
and model_config.use_mla):
|
||||
self.use_mla = True
|
||||
self.use_layerwise = use_layerwise
|
||||
self.tp_rank = parallel_config.rank
|
||||
self.tp_size = parallel_config.tensor_parallel_size
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"register_buffer", False)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.current_layer = 0
|
||||
# self.use_mla = first_kv_cache_tuple[0].size(
|
||||
# -1) != first_kv_cache_tuple[1].size(-1)
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
num_kv_head = model_config.get_num_kv_heads(parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
kv_dtype = get_kv_cache_torch_dtype(
|
||||
vllm_config.cache_config.cache_dtype, model_config.dtype)
|
||||
self.hidden_dim_size = num_kv_head * head_size
|
||||
if self.use_mla:
|
||||
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
|
||||
else:
|
||||
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
|
||||
head_size)
|
||||
self.metadata = MooncakeEngineMetadata(
|
||||
model_config.model,
|
||||
parallel_config.world_size,
|
||||
parallel_config.rank,
|
||||
kv_dtype,
|
||||
kv_shape,
|
||||
self.block_size,
|
||||
self.use_mla,
|
||||
)
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata)
|
||||
|
||||
self.m_store = Mooncakestore(parallel_config)
|
||||
|
||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
self.kv_caches_base_addr = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
self._register(base_addr, region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
self._register(base_addr, region_len)
|
||||
|
||||
if self.use_layerwise:
|
||||
self.get_event = threading.Event()
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending,
|
||||
self.num_layers)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database, self.block_len,
|
||||
self.block_size, ready_event, self.get_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def _register(self, ptr, length):
|
||||
logger.debug(
|
||||
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
|
||||
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
|
||||
try:
|
||||
self.m_store.register_buffer(ptr, length)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Mooncake memory registration failed. Error is: {e}")
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
self.current_layer = 0
|
||||
self.layerwise_retrievers = []
|
||||
for request in metadata.requests:
|
||||
load_spec = request.load_spec
|
||||
if load_spec is None or not load_spec.can_load: #load =0
|
||||
continue
|
||||
tokens = request.token_ids
|
||||
req_id = request.req_id
|
||||
if (load_spec.mooncake_cached_tokens % self.block_size
|
||||
!= 0) and (load_spec.mooncake_cached_tokens
|
||||
== tokens.shape[0] - 1):
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
|
||||
else:
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
|
||||
masked_token_count = (request.load_spec.vllm_cached_tokens //
|
||||
self.block_size * self.block_size)
|
||||
token_mask = torch.ones_like(tokens, dtype=torch.bool)
|
||||
token_mask[:masked_token_count] = False
|
||||
if self.use_layerwise:
|
||||
layerwise_retriever = self.retrieve_layer(
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
next(layerwise_retriever) # first layer load
|
||||
self.layerwise_retrievers.append(layerwise_retriever)
|
||||
else:
|
||||
if self.load_async:
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
tokens,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
)
|
||||
else:
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
self.m_store.get_batch(key_list, addr_list, size_list,
|
||||
blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, _ = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
self.m_store.get(key, addr, size)
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
"""MooncakeConnector does not do layerwise saving."""
|
||||
for layerwise_retriever in self.layerwise_retrievers:
|
||||
ret_token_mask = next(layerwise_retriever)
|
||||
if self.current_layer == self.num_layers - 1:
|
||||
assert ret_token_mask is not None
|
||||
num_retrieved_tokens = ret_token_mask.sum().item()
|
||||
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
|
||||
|
||||
def save_kv_layer(self,
|
||||
connector_metadata: MooncakeConnectorMetadata) -> None:
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
if self.current_layer == 0:
|
||||
self.layerwise_storers = []
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
# TODO: check whether the save thread needs to be removed
|
||||
# no lookup, skipmask
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
layerwise_storer = self.store_layer(
|
||||
req_id,
|
||||
token_ids,
|
||||
mask=store_mask,
|
||||
block_ids=request.block_ids,
|
||||
)
|
||||
self.layerwise_storers.append(layerwise_storer)
|
||||
for layerwise_storer in self.layerwise_storers:
|
||||
try:
|
||||
next(layerwise_storer)
|
||||
except Exception:
|
||||
raise
|
||||
self.current_layer = self.current_layer + 1
|
||||
|
||||
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
token_ids,
|
||||
request.block_ids,
|
||||
store_mask,
|
||||
request.is_last_chunk,
|
||||
current_event,
|
||||
)
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[Optional[torch.Tensor], None, None]:
|
||||
"""
|
||||
Retrieve the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens need to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the KV transfer which
|
||||
will be passed into the npu_transfer.
|
||||
|
||||
return: A generator that yields Optional[torch.Tensor]. The tensor will
|
||||
be the boolean mask indicating which tokens are retrieved and will
|
||||
only be returned in the last iteration.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_required_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_required_tokens = len(tokens)
|
||||
|
||||
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
first_flag = True
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer)
|
||||
ret_mask[start:end] = True
|
||||
|
||||
if keys:
|
||||
# Transpose the keys into layer major format
|
||||
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
if not first_flag:
|
||||
is_finish = self.get_event.wait(timeout=3) #try---cache
|
||||
if not is_finish:
|
||||
logger.info("Layerwise get failed")
|
||||
self.get_event.clear()
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
first_flag = False
|
||||
yield None
|
||||
else:
|
||||
# If no cache entries are found, we still need to yield to avoid
|
||||
# `StopIteration`
|
||||
for layer_id in range(self.num_layers):
|
||||
yield None
|
||||
|
||||
retrieved_tokens = torch.sum(ret_mask)
|
||||
logger.debug(f"Retrieved {retrieved_tokens} "
|
||||
f"out of {num_required_tokens} "
|
||||
f"out of total {len(tokens)} tokens")
|
||||
|
||||
yield ret_mask
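# Usage sketch (mirrors start_load_kv and wait_for_layer_load above): the
# caller drives this generator once per layer, for example
#   retriever = self.retrieve_layer(req_id, tokens, block_ids, mask)
#   next(retriever)                    # queue the load of layer 0
#   for _ in range(self.num_layers):
#       ret_mask = next(retriever)     # advance; only the last call yields the mask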
|
||||
|
||||
def store_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Store the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens need to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the storage backend which
|
||||
will be passed into the gpu_connector.
|
||||
|
||||
return: A generator that yields None. In the first iteration, the
|
||||
generator allocates the memory objects for all layers and moves
|
||||
the KV cache of the first layer from GPU to CPU. In the next
|
||||
iterations, it moves the KV cache of layer i from GPU to the memory
|
||||
objects (on CPU) and puts the memory objects of layer i-1 to the
|
||||
storage backends. In the last iteration, it puts the memory objects
|
||||
of the last layer to the storage backends.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_stored_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_stored_tokens = len(tokens)
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer) #[block_num,layer_num]
|
||||
|
||||
if keys:
|
||||
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
yield
|
||||
else:
|
||||
for layer_id in range(self.num_layers):
|
||||
yield
|
||||
logger.debug(
|
||||
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.kv_send_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
|
||||
|
||||
done_recving = (
|
||||
self.kv_recv_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.load_async else set())
|
||||
|
||||
logger.debug(
|
||||
"Number of completed KV cache send requests: %d, receive "
|
||||
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
|
||||
self.tp_rank)
|
||||
return done_sending, done_recving
|
||||
|
||||
def wait_layer_transfer_finish(self):
|
||||
time.sleep(10)
|
||||
pass
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
res = self.m_store.batch_exists(
|
||||
keys) # type: ignore[assignment]
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return starts[index]
|
||||
# all tokens were found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def lookup_scheduler(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
multi_tp_keys = keys[:]
|
||||
for i in range(1, self.tp_size):
|
||||
for item in keys:
|
||||
new_str = item.replace( # type: ignore[attr-defined]
|
||||
"@0", f"@{i}", 1)
|
||||
multi_tp_keys.append(new_str)
|
||||
res = self.m_store.batch_exists(
|
||||
multi_tp_keys) # type: ignore[assignment]
|
||||
num_block = len(keys)
|
||||
multi_tp_values = [
|
||||
res[i * num_block:(i + 1) *
|
||||
num_block] # type: ignore[index]
|
||||
for i in range(self.tp_size)
|
||||
]
|
||||
index = self.find_min_first_non_one_index(multi_tp_values)
|
||||
if index != -1:
|
||||
return starts[index]
|
||||
# all tokens were found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def find_min_first_non_one_index(self, arr):
|
||||
try:
|
||||
return min(idx for row in arr for idx, val in enumerate(row)
|
||||
if val != 1)
|
||||
except ValueError:
|
||||
return -1
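# Worked example: batch_exists returns 1 for keys that are present, so for
#   arr = [[1, 1, 0],   # tp rank 0: third block missing
#          [1, 0, 1]]   # tp rank 1: second block missing
# the smallest index whose value is not 1 is 1, i.e. the shared prefix is
# only usable up to the start of the second block; an all-ones input
# raises ValueError internally and therefore returns -1.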
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the cache engine and free all the resources"""
|
||||
self.m_store.close()
|
||||
vllm_npu/distributed/mooncake/mooncake_store.py (new file, 126 lines)
@@ -0,0 +1,126 @@
|
||||
# Standard
|
||||
import os
|
||||
|
||||
# Third Party
|
||||
from mooncake.store import ReplicateConfig # type: ignore
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from vllm.utils import get_ip, logger
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import MooncakeEngineKey
|
||||
from vllm_npu.distributed.mooncake.transfer_engine import get_global_te
|
||||
|
||||
from .config_data import MooncakeStoreConfig
|
||||
|
||||
METADATA_BYTES_LEN = 24
|
||||
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
|
||||
|
||||
|
||||
class Mooncakestore():
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
try:
|
||||
from mooncake.store import MooncakeDistributedStore # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank_local
|
||||
all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
|
||||
if not all_device_ids:
|
||||
device_ids_list = list(
|
||||
range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
|
||||
else:
|
||||
device_ids_list = list(map(int, all_device_ids.split(',')))
|
||||
assert len(device_ids_list) > tp_rank
|
||||
device_id = device_ids_list[tp_rank]
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
self.store = MooncakeDistributedStore()
|
||||
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
|
||||
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
|
||||
":npu_" + str(device_id)
|
||||
ret = self.store.setup(local_hostname, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address)
|
||||
else:
|
||||
local_hostname = get_ip()
|
||||
transfer_engine = get_global_te(local_hostname, device_name=None)
|
||||
self.local_seg = local_hostname + ":" + str(
|
||||
transfer_engine.get_rpc_port())
|
||||
ret = self.store.setup(self.local_seg, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address,
|
||||
transfer_engine.get_engine())
|
||||
if ret != 0:
|
||||
msg = "Initialize mooncake failed."
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
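# Rough environment sketch for bringing up Mooncakestore on a 2-way TP
# worker (the device numbering and ports are illustrative only):
#
#   export MOONCAKE_CONFIG_PATH=/etc/mooncake.json   # read by load_from_env()
#   export VLLM_BASE_PORT=8790                       # matches the default above
#   export ASCEND_RT_VISIBLE_DEVICES=0,1
#
# With protocol "ascend" and use_ascend_direct disabled, tp_rank 1 maps to
# device_id 1 and registers itself as "<local-ip>:8791:npu_1"
# (BASE_PORT + device_id).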
|
||||
|
||||
def exists(self, key: MooncakeEngineKey) -> bool:
|
||||
return self.store.is_exist(key.to_string()) == 1
|
||||
|
||||
def batch_exists(self, keys: list[str]) -> list[int]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
def register_buffer(self, ptr, length):
|
||||
return self.store.register_buffer(ptr, length)
|
||||
|
||||
def get_batch(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]], block_ids: list[int]):
|
||||
try:
|
||||
res = self.store.batch_get_into_multi_buffers(
|
||||
keys, addrs, sizes, True)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to get key {keys},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get key {keys}. {e}")
|
||||
|
||||
def put_batch(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]], block_ids: list[int]):
|
||||
try:
|
||||
config = ReplicateConfig()
|
||||
config.preferred_segment = self.local_seg
|
||||
config.prefer_alloc_in_same_node = True
|
||||
res = self.store.batch_put_from_multi_buffers(
|
||||
keys, addrs, sizes, config)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to put key {keys},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to put key {keys},error:{e}")
|
||||
|
||||
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||
expect_res = sum(size)
|
||||
key_str = key.to_string()
|
||||
try:
|
||||
res = self.store.batch_get_into_ascend(key_str, addr, size)
|
||||
if res[0] != expect_res:
|
||||
logger.error(f"Failed to get key: [{key_str}] .")
|
||||
except Exception:
|
||||
logger.error(f"Failed to get key: [{key_str}] .")
|
||||
return res
|
||||
|
||||
def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||
key_str = key.to_string()
|
||||
try:
|
||||
ret = self.store.batch_put_from_ascend(key_str, addr, size)
|
||||
if ret[0] != 0:
|
||||
logger.error(f"Failed to put key {key_str}.")
|
||||
except Exception:
|
||||
logger.error(f"Failed to put key {key_str}.")
|
||||
|
||||
return ret
|
||||
|
||||
def close(self):
|
||||
self.store.close()
|
||||
logger.info("Closed the mooncake store connection")
|
||||
vllm_npu/distributed/mooncake/mooncake_store_connector_v1.py (new file, 494 lines)
@@ -0,0 +1,494 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import logger, make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
from vllm_npu.distributed.mooncake.config_data import (
|
||||
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
|
||||
from vllm_npu.distributed.mooncake.mooncake_engine import MooncakeEngine
|
||||
|
||||
|
||||
class MooncakeConnectorV1(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
|
||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"use_layerwise", False)
|
||||
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.sended_but_unfinished_reqs: set[str] = set()
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
|
||||
vllm_config, self.use_layerwise)
|
||||
else:
|
||||
self.connector_worker = MooncakeEngine(
|
||||
vllm_config,
|
||||
self.use_layerwise,
|
||||
)
|
||||
|
||||
assert self.connector_worker is not None
|
||||
if vllm_config.parallel_config.rank == 0 and self.kv_role != "kv_consumer":
|
||||
self.lookup_server = MooncakeLookupServer(
|
||||
self.connector_worker, vllm_config, self.use_layerwise)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._get_connector_metadata(),
|
||||
MooncakeConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""MooncakeStoreConnector does not do layerwise saving."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
self.connector_worker.save_kv_layer(self._get_connector_metadata())
|
||||
|
||||
def wait_for_save(self):
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
|
||||
if self.use_layerwise:
|
||||
self.connector_worker.wait_layer_transfer_finish()
|
||||
return
|
||||
|
||||
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
meta = self._get_connector_metadata()
|
||||
done_sending, done_recving = self.connector_worker.get_finished()
|
||||
sended_and_finished: set[str] = set()
|
||||
for item in list(self.sended_but_unfinished_reqs):
|
||||
if item not in meta.unfinished_request_ids:
|
||||
sended_and_finished.add(item)
|
||||
self.sended_but_unfinished_reqs.remove(item)
|
||||
for item in done_sending:
|
||||
if item in meta.unfinished_request_ids:
|
||||
self.sended_but_unfinished_reqs.add(item)
|
||||
else:
|
||||
sended_and_finished.add(item)
|
||||
|
||||
return sended_and_finished, done_recving
|
||||
|
||||
|
||||
def get_zmq_rpc_path_mooncake(
|
||||
vllm_config: Optional["VllmConfig"] = None, ) -> str:
|
||||
base_url = envs.VLLM_RPC_BASE_PATH
|
||||
# Default to 0 if not configured
|
||||
rpc_port = 0
|
||||
if vllm_config is not None:
|
||||
rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"mooncake_rpc_port", 0)
|
||||
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
|
||||
return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
|
||||
|
||||
|
||||
class MooncakeStoreConnectorV1Scheduler:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
||||
self.use_layerwise = use_layerwise
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.client = MooncakeLookupClient(
|
||||
vllm_config) if self.kv_role != "kv_consumer" else None
|
||||
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_load", False)
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
# request_id -> (vllm cached tokens, mooncake cached tokens)
|
||||
self.load_specs: dict[str, LoadSpec] = {}
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
# request_id -> full_token_ids
|
||||
self._request_trackers: dict[str, RequestTracker] = {}
|
||||
# Whether to discard partial chunks
|
||||
self._discard_partial_chunks = (
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"discard_partial_chunks", True))
|
||||
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._unfinished_request_ids: set[str] = set()
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Check for external KV cache hit.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
|
||||
return 0, False
|
||||
|
||||
if self._discard_partial_chunks:
|
||||
token_block_end = len(request.prompt_token_ids
|
||||
) // self._block_size * self._block_size
|
||||
token_ids = torch.tensor(
|
||||
request.prompt_token_ids[:token_block_end])
|
||||
else:
|
||||
token_ids = torch.tensor(request.prompt_token_ids)
|
||||
|
||||
num_external_hit_tokens = self.client.lookup( # type: ignore[union-attr]
|
||||
token_ids)
|
||||
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
num_external_hit_tokens -= 1
|
||||
|
||||
need_to_allocate = num_external_hit_tokens - num_computed_tokens
|
||||
|
||||
logger.info(
|
||||
"Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
|
||||
request.request_id,
|
||||
request.num_tokens,
|
||||
num_external_hit_tokens,
|
||||
need_to_allocate,
|
||||
)
|
||||
|
||||
if need_to_allocate <= 0:
|
||||
return 0, False
|
||||
|
||||
self.load_specs[request.request_id] = LoadSpec(
|
||||
vllm_cached_tokens=num_computed_tokens,
|
||||
mooncake_cached_tokens=num_external_hit_tokens,
|
||||
can_load=False,
|
||||
)
|
||||
|
||||
return need_to_allocate, self.load_async
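# Worked example (hypothetical numbers): a 300-token prompt with
# block_size=128 and discard_partial_chunks=True is looked up only for its
# first 256 tokens. If mooncake reports 256 hit tokens while vLLM already
# has 128 computed locally, need_to_allocate = 256 - 128 = 128 extra tokens
# are reported to the scheduler (loaded asynchronously when load_async is set).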
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after temporary buffer alloc.
|
||||
|
||||
Update the per-request load state if the KV cache manager
allocated blocks for this request.
|
||||
"""
|
||||
local_block_ids = []
|
||||
if num_external_tokens > 0:
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
|
||||
self._unfinished_requests[request.request_id] = (request,
|
||||
local_block_ids)
|
||||
self._unfinished_request_ids.add(request.request_id)
|
||||
if request.request_id not in self.load_specs:
|
||||
# No KV tokens from external KV cache, return
|
||||
return
|
||||
|
||||
if num_external_tokens == 0:
|
||||
# No need to load anything
|
||||
self.load_specs[request.request_id].can_load = False
|
||||
return
|
||||
|
||||
assert (
|
||||
num_external_tokens > 0 and num_external_tokens
|
||||
== self.load_specs[request.request_id].mooncake_cached_tokens -
|
||||
self.load_specs[request.request_id].vllm_cached_tokens
|
||||
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
|
||||
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
||||
f" for request {request.request_id}")
|
||||
|
||||
self.load_specs[request.request_id].can_load = True
|
||||
|
||||
    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """Build the connector metadata for this scheduler step.

        This function should NOT modify other fields in the scheduler_output
        except the `kv_connector_metadata` field.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """

        force_skip_save = self.kv_role == "kv_consumer"

        for finished_req_id in scheduler_output.finished_req_ids:
            self._request_trackers.pop(finished_req_id, None)
            self._unfinished_requests.pop(finished_req_id, None)
            self._unfinished_request_ids.discard(finished_req_id)

        meta = MooncakeConnectorMetadata(self._unfinished_request_ids)

        for request in scheduler_output.scheduled_new_reqs:
            # Right now, we only load KV for new requests.
            load_spec = self.load_specs.pop(request.req_id, None)
            num_tokens_to_compute = (
                request.num_computed_tokens +
                scheduler_output.num_scheduled_tokens[request.req_id])
            request_tracker = RequestTracker.from_new_request(
                request, num_tokens_to_compute)
            self._request_trackers[request.req_id] = request_tracker
            last_chunk_tokens_num = ((len(request.prompt_token_ids) //
                                      self._block_size * self._block_size)
                                     if self._discard_partial_chunks else len(
                                         request.prompt_token_ids))
            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
                load_spec=load_spec,
                skip_save=force_skip_save,
                is_last_chunk=len(request_tracker.token_ids)
                >= last_chunk_tokens_num,
                discard_partial_chunks=self._discard_partial_chunks,
            )
            if req_meta is not None:
                meta.add_request(req_meta)

        cached_reqs = scheduler_output.scheduled_cached_reqs
        if isinstance(cached_reqs, list) and not force_skip_save:
            for i, req in enumerate(cached_reqs):
                request_tracker = self._request_trackers[req.req_id]
                request_tracker.update(req.new_token_ids, req.new_block_ids)
                last_chunk_tokens_num = ((len(req.prompt_token_ids) //
                                          self._block_size * self._block_size)
                                         if self._discard_partial_chunks else
                                         len(req.prompt_token_ids))
                req_meta = ReqMeta.from_request_tracker(
                    request_tracker,
                    self._block_size,
                    load_spec=None,
                    skip_save=force_skip_save,
                    is_last_chunk=len(request_tracker.token_ids)
                    >= last_chunk_tokens_num,
                    discard_partial_chunks=self._discard_partial_chunks,
                )
                if req_meta is not None:
                    meta.add_request(req_meta)
        elif not force_skip_save:
            for i, req_id in enumerate(cached_reqs.req_ids):
                request_tracker = self._request_trackers[req_id]
                num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
                req_tuple = self._unfinished_requests.get(req_id)
                if req_tuple:
                    request = req_tuple[0]
                    num_current_tokens = len(request_tracker.token_ids)
                    new_token_ids = request.all_token_ids[
                        num_current_tokens:num_current_tokens + num_new_tokens]
                else:
                    raise ValueError(
                        f"Request {req_id} is not in _unfinished_requests, "
                        f"but it is scheduled to be cached")
                new_block_ids = cached_reqs.new_block_ids[i]
                if not new_block_ids:
                    continue
                request_tracker.update(new_token_ids, new_block_ids)
                # Decode-phase tokens are not saved.
                if len(request_tracker.token_ids) > len(
                        request.prompt_token_ids):
                    continue

                last_chunk_tokens_num = ((len(request.prompt_token_ids) //
                                          self._block_size * self._block_size)
                                         if self._discard_partial_chunks else
                                         len(request.prompt_token_ids))
                req_meta = ReqMeta.from_request_tracker(
                    request_tracker,
                    self._block_size,
                    load_spec=None,
                    skip_save=force_skip_save,
                    is_last_chunk=len(request_tracker.token_ids)
                    >= last_chunk_tokens_num,
                    discard_partial_chunks=self._discard_partial_chunks,
                )
                if req_meta is not None:
                    meta.add_request(req_meta)

        request_ids = [
            req.req_id for req in scheduler_output.scheduled_new_reqs
        ]
        for request_id, (request,
                         block_ids) in self._unfinished_requests.items():
            if request_id not in request_ids and request_id not in cached_reqs.req_ids:
                load_spec = self.load_specs.pop(request_id, None)
                if not load_spec:
                    continue
                num_tokens_to_compute = load_spec.mooncake_cached_tokens
                if (num_tokens_to_compute % self._block_size
                        != 0) and (num_tokens_to_compute
                                   == len(request.prompt_token_ids) - 1):
                    num_tokens_to_compute = num_tokens_to_compute + 1
                request_tracker = RequestTracker(
                    req_id=request_id,
                    token_ids=request.prompt_token_ids[:num_tokens_to_compute].
                    copy(),
                    allocated_block_ids=block_ids,
                    num_saved_tokens=0,
                )

                self._request_trackers[request_id] = request_tracker

                req_meta = ReqMeta.from_request_tracker(
                    request_tracker,
                    self._block_size,
                    load_spec=load_spec,
                    skip_save=None,
                    discard_partial_chunks=self._discard_partial_chunks,
                )
                if req_meta is not None:
                    meta.add_request(req_meta)
        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.
        """
        if self.kv_role == "kv_consumer":
            return False, None
        tracker = self._request_trackers.get(request.request_id)
        if tracker is not None and tracker.num_saved_tokens <= 0:
            return False, None
        delay_free_blocks = len(block_ids) > 0
        if delay_free_blocks:
            logger.info("Delaying free of %d blocks for request %s",
                        len(block_ids), request.request_id)
        return delay_free_blocks, None


class MooncakeLookupClient:

    def __init__(self, vllm_config: "VllmConfig"):
        self.encoder = MsgpackEncoder()
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        socket_path = get_zmq_rpc_path_mooncake(vllm_config)
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REQ,  # type: ignore[attr-defined]
            bind=False,
        )

    def lookup(self, token_ids: torch.Tensor) -> int:
        request = self.encoder.encode(token_ids)
        self.socket.send_multipart(request, copy=False)
        resp = self.socket.recv()
        result = int.from_bytes(resp, "big")
        return result

    def close(self):
        self.socket.close(linger=0)


class MooncakeLookupServer:

    def __init__(
        self,
        mooncake_engine: MooncakeEngine,
        vllm_config: "VllmConfig",
        use_layerwise: bool,
    ):
        self.decoder = MsgpackDecoder(torch.Tensor)
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        socket_path = get_zmq_rpc_path_mooncake(vllm_config)
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )

        self.mooncake_engine = mooncake_engine
        self.running = True

        def process_request():
            while self.running:
                frames = self.socket.recv_multipart(copy=False)
                token_ids = self.decoder.decode(frames)
                result = self.mooncake_engine.lookup_scheduler(
                    token_ids, use_layerwise)
                response = result.to_bytes(4, "big")
                self.socket.send(response)

        self.thread = threading.Thread(target=process_request, daemon=True)
        self.thread.start()

    def close(self):
        self.socket.close(linger=0)
        # TODO: close the thread!
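
The client and server above speak a tiny ZMQ REQ/REP protocol: the client msgpack-encodes a token-ID tensor, and the server replies with the lookup result as a 4-byte big-endian integer. A minimal round-trip sketch follows (illustration only, not part of this diff); it assumes an initialized MooncakeEngine instance and a populated VllmConfig, both created elsewhere.

import torch

# `engine` and `vllm_config` are assumed to exist; they are not defined here.
server = MooncakeLookupServer(engine, vllm_config, use_layerwise=False)
client = MooncakeLookupClient(vllm_config)

token_ids = torch.tensor([1, 2, 3, 4], dtype=torch.int64)
num_hit_tokens = client.lookup(token_ids)  # blocks until the server thread replies

client.close()
server.close()  # note: the background REP thread is not joined (see TODO above)
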
38 vllm_npu/distributed/mooncake/transfer_engine.py Normal file
@@ -0,0 +1,38 @@
import ipaddress
import threading
from typing import Optional

from mooncake.engine import TransferEngine  # type: ignore

_global_te = None
_global_te_lock = threading.Lock()


def get_global_te(hostname: str, device_name: Optional[str]):
    try:
        ip = ipaddress.ip_address(hostname)
        if isinstance(ip, ipaddress.IPv6Address):
            raise RuntimeError(
                "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
            )
    except ValueError:
        pass

    global _global_te
    if _global_te is None:
        with _global_te_lock:
            # Double-Checked Locking
            if _global_te is None:
                if TransferEngine is None:
                    raise RuntimeError("mooncake is not available")
                transfer_engine = TransferEngine()
                device_name = device_name if device_name is not None else ""
                ret_value = transfer_engine.initialize(hostname,
                                                       "P2PHANDSHAKE",
                                                       "ascend", device_name)
                if ret_value != 0:
                    raise RuntimeError(
                        f"TransferEngine initialization failed with ret_value: {ret_value}"
                    )
                _global_te = transfer_engine
    return _global_te
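
get_global_te above is a process-wide singleton guarded by double-checked locking: whatever arguments the first caller passes decide how the engine is initialized, and every later caller gets the cached instance back. A small sketch (illustrative only; the address below is a placeholder, not a value from this diff):

te_first = get_global_te("192.0.2.10", device_name=None)  # initializes the engine
te_again = get_global_te("192.0.2.10", device_name=None)  # returns the cached one
assert te_first is te_again
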
1263 vllm_npu/distributed/mooncake_connector.py Normal file
File diff suppressed because it is too large

1153 vllm_npu/distributed/mooncake_layerwise_connector.py Normal file
File diff suppressed because it is too large

196 vllm_npu/distributed/parallel_state.py Normal file
@@ -0,0 +1,196 @@
from typing import Optional

import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                             init_model_parallel_group)

import vllm_npu.envs as envs_ascend
from vllm_npu.ascend_config import get_ascend_config

# Currently, the mc2 op needs its own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
    assert _MC2 is not None, ("mc2 group is not initialized")
    return _MC2


def get_otp_group() -> GroupCoordinator:
    assert _OTP is not None, (
        "output tensor parallel group is not initialized")
    return _OTP


def get_lmhead_tp_group() -> GroupCoordinator:
    assert _LMTP is not None, (
        "lm head tensor parallel group is not initialized")
    return _LMTP


def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
    return _MLP_TP


def get_p_tp_group() -> GroupCoordinator:
    assert _P_TP is not None, (
        "distributed prefill tensor parallel group is not initialized")
    return _P_TP


def model_parallel_initialized():
    return (_MC2 is not None)


def init_ascend_model_parallel(parallel_config: ParallelConfig):
    if model_parallel_initialized():
        return
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    backend = torch.distributed.get_backend(get_world_group().device_group)

    # The layout of all ranks: ExternalDP * EP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    all_ranks = torch.arange(world_size).reshape(
        -1, parallel_config.data_parallel_size *
        parallel_config.tensor_parallel_size)

    pd_tp_ratio = get_ascend_config().pd_tp_ratio
    pd_head_ratio = get_ascend_config().pd_head_ratio
    global _P_TP
    assert _P_TP is None, (
        "distributed prefill tensor parallel group is already initialized")
    prefill_tensor_model_parallel_size = pd_tp_ratio
    # divide alltoall groups
    if pd_head_ratio > 1 and get_current_vllm_config(
    ).kv_transfer_config.is_kv_producer:
        num_head_replica = get_ascend_config().num_head_replica
        remote_tp_size = parallel_config.tensor_parallel_size // pd_tp_ratio
        if num_head_replica <= 1:
            group_ranks = all_ranks.view(
                -1, prefill_tensor_model_parallel_size).unbind(0)
        else:
            group_ranks = all_ranks.clone().view(
                parallel_config.data_parallel_size, -1,
                num_head_replica)  # [DP_size, num_head, num_head_replica]
            group_ranks = group_ranks.permute(0, 2, 1)
            group_ranks = group_ranks.reshape(
                -1,
                group_ranks.size(-1))  # [DP_size * num_head_replica, num_head]
            alltoall_group_size = group_ranks.size(-1) // remote_tp_size
            group_ranks = group_ranks.unsqueeze(-1).view(
                parallel_config.data_parallel_size, num_head_replica, -1,
                alltoall_group_size
            )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
            group_ranks = group_ranks.reshape(-1,
                                              alltoall_group_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        local_rank = get_world_group().local_rank
        num = next(
            (i for i, ranks in enumerate(group_ranks) if local_rank in ranks),
            None)
        _P_TP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name=f"p_tp_{num}")

    global _MC2
    group_ranks = all_ranks.unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    _MC2 = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="mc2")
    if envs_ascend.vllm_npu_ENABLE_MLP_OPTIMIZE:
        global _MLP_TP
        assert _MLP_TP is None, (
            "mlp tensor model parallel group is already initialized")

        mlp_tp = parallel_config.data_parallel_size

        all_ranks_mlp_head = torch.arange(world_size).reshape(
            -1, mlp_tp, parallel_config.pipeline_parallel_size, 1)  # noqa
        group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]

        # message queue broadcaster is only used in tensor model parallel group
        _MLP_TP = init_model_parallel_group(group_ranks,
                                            get_world_group().local_rank,
                                            backend,
                                            group_name="mlp_tp")

    # If oproj tensor parallel size is set, we will create a group for it.
    otp_size = get_ascend_config().oproj_tensor_parallel_size
    if otp_size is not None:
        group_ranks = []
        global _OTP
        num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
        for i in range(num_oproj_tensor_parallel_groups):
            ranks = list(range(i * otp_size, (i + 1) * otp_size))
            group_ranks.append(ranks)
        _OTP = init_model_parallel_group(group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         group_name="otp")

    lmhead_tensor_parallel_size = get_ascend_config(
    ).lmhead_tensor_parallel_size
    if lmhead_tensor_parallel_size is not None:
        group_ranks = []
        global _LMTP
        num_lmhead_tensor_parallel_groups: int = (world_size //
                                                  lmhead_tensor_parallel_size)
        for i in range(num_lmhead_tensor_parallel_groups):
            ranks = list(
                range(i * lmhead_tensor_parallel_size,
                      (i + 1) * lmhead_tensor_parallel_size))
            group_ranks.append(ranks)
        _LMTP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name="lmheadtp")


def get_mlp_tensor_model_parallel_world_size():
    """Return world size for the MLP tensor model parallel group."""
    return get_mlp_tp_group().world_size


def get_mlp_tensor_model_parallel_rank():
    """Return the caller's rank within the MLP tensor model parallel group."""
    return get_mlp_tp_group().rank_in_group


def destroy_ascend_model_parallel():
    global _MC2
    if _MC2:
        _MC2.destroy()
    _MC2 = None

    global _MLP_TP
    if _MLP_TP:
        _MLP_TP.destroy()
    _MLP_TP = None

    global _LMTP
    if _LMTP:
        _LMTP.destroy()
    _LMTP = None

    global _OTP
    if _OTP:
        _OTP.destroy()
    _OTP = None

    global _P_TP
    if _P_TP:
        _P_TP.destroy()
    _P_TP = None
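
To make the rank-layout arithmetic in init_ascend_model_parallel concrete, here is a small worked example (illustration only, not part of the diff) with hypothetical sizes: world_size = 8, data_parallel_size = 2, tensor_parallel_size = 4, and pd_tp_ratio = 2 with num_head_replica <= 1.

import torch

world_size, dp, tp, pd_tp_ratio = 8, 2, 4, 2
all_ranks = torch.arange(world_size).reshape(-1, dp * tp)  # shape (1, 8)

# MC2 takes one group per row of all_ranks:
mc2_groups = [x.tolist() for x in all_ranks.unbind(0)]
# -> [[0, 1, 2, 3, 4, 5, 6, 7]]

# With num_head_replica <= 1, the prefill TP groups are contiguous slices
# of size pd_tp_ratio:
p_tp_groups = [x.tolist() for x in all_ranks.view(-1, pd_tp_ratio).unbind(0)]
# -> [[0, 1], [2, 3], [4, 5], [6, 7]]
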
61 vllm_npu/distributed/utils.py Normal file
@@ -0,0 +1,61 @@
import os

import torch
import torch.distributed as dist

from vllm_npu.distributed.parallel_state import get_p_tp_group


def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
                              value: torch.Tensor):
    if pd_tp_ratio <= 1:
        return None, None
    elif key is None or value is None:
        raise ValueError("key or value is None")
    k_output = alltoall_and_rearrange(pd_tp_ratio, key)
    v_output = alltoall_and_rearrange(pd_tp_ratio, value)
    return k_output, v_output


def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
    num_kv_heads = input_tensor.size(1)
    output_tensor = torch.zeros_like(input_tensor)
    dist.all_to_all_single(output_tensor,
                           input_tensor,
                           group=get_p_tp_group().device_group)
    # Drop the local references early so the buffers can be reclaimed.
    input_tensor = 0
    result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
    output_tensor = 0
    return result


def rearrange_output(base_output: torch.Tensor, cut_num: int,
                     num_kv_heads: int):
    size_0 = base_output.size(0)
    if size_0 % cut_num != 0:
        raise ValueError(
            f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
        )
    chunk_size = size_0 // cut_num
    reshaped = base_output.view(cut_num, chunk_size, -1)
    transposed = reshaped.transpose(0, 1)
    return transposed.contiguous().view(size_0, num_kv_heads, -1)


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]


def get_transfer_timeout_value():
    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
    if len(ascend_transfer_timeout) > 0:
        return int(ascend_transfer_timeout)
    hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT',
                                      '20'))  # type: ignore
    hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT',
                                        '7'))  # type: ignore
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
               3000)
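
As a quick sanity check on get_transfer_timeout_value (illustration only, not part of the diff): the fallback path applies the usual RDMA timeout formula, presumably 4.096 µs × 2^HCCL_RDMA_TIMEOUT per attempt, multiplied by the retry count, floored to milliseconds and padded by 3000 ms. With the defaults of 20 and 7 that works out to roughly 33 seconds.

# 4.096 * 2**20 ~= 4_294_967 per attempt; * 7 retries ~= 30_064_771;
# // 1000 -> 30_064; + 3000 -> 33_064.
assert int((4.096 * (2 ** 20)) * 7 // 1000 + 3000) == 33064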