This commit is contained in:
2026-02-10 23:08:39 +08:00
parent 1baa36026c
commit 6680585975
172 changed files with 52867 additions and 892 deletions

View File

@@ -1 +1,40 @@
"""Ascend NPU distributed communication (HCCL)."""
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed.kv_transfer.kv_connector.factory import \
KVConnectorFactory
def register_connector():
KVConnectorFactory.register_connector(
"LLMDataDistCMgrConnector",
"vllm_npu.distributed.llmdatadist_c_mgr_connector",
"LLMDataDistCMgrConnector")
KVConnectorFactory.register_connector(
"MooncakeConnectorV1", "vllm_npu.distributed.mooncake_connector",
"MooncakeConnector")
KVConnectorFactory.register_connector(
"MooncakeConnectorStoreV1",
"vllm_npu.distributed.mooncake.mooncake_store_connector_v1",
"MooncakeConnectorV1")
KVConnectorFactory.register_connector(
"MooncakeLayerwiseConnector",
"vllm_npu.distributed.mooncake_layerwise_connector",
"MooncakeLayerwiseConnector")

View File

@@ -1,42 +1,46 @@
"""
NPUCommunicator — HCCL-based device communicator for Ascend NPU.
Extends ``DeviceCommunicatorBase`` with NPU-specific collective
operations using the HCCL backend.
"""
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, Optional
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase,
)
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
class NPUCommunicator(DeviceCommunicatorBase):
"""Device communicator for Ascend NPU using HCCL."""
def __init__(
self,
cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[dist.ProcessGroup] = None,
unique_name: str = "",
):
def __init__(self,
cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[dist.ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
import torch_npu # noqa: F401
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
# init device according to rank
self.device = torch.npu.current_device()
def all_to_all(
self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None,
gather_sizes: Optional[List[int]] = None,
) -> torch.Tensor:
"""All-to-all communication for NPU tensors."""
def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
if scatter_dim < 0:
scatter_dim += input_.dim()
if gather_dim < 0:
@@ -53,22 +57,17 @@ class NPUCommunicator(DeviceCommunicatorBase):
tensor_shape = list(tensor_shape_base)
tensor_shape[gather_dim] = gather_sizes[i]
output_list.append(
torch.empty(
tensor_shape,
dtype=input_.dtype,
device=input_.device,
)
)
torch.empty(tensor_shape,
dtype=input_.dtype,
device=input_.device))
else:
input_list = [
t.contiguous()
for t in torch.tensor_split(
input_, self.world_size, scatter_dim
)
t.contiguous() for t in torch.tensor_split(
input_, self.world_size, scatter_dim)
]
output_list = [
torch.empty_like(input_list[i])
for i in range(self.world_size)
torch.empty_like(input_list[i]) for i in range(self.world_size)
]
dist.all_to_all(output_list, input_list, group=self.device_group)

View File

@@ -0,0 +1,471 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import queue
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Sequence
import torch
from vllm.attention import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.utils import logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
MLAAttentionSpec)
from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.distributed.cpu_offload_manager.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@dataclass
class ReqMeta:
gpu_block_ids: list[int]
cpu_block_ids: list[int]
num_scheduled_tokens: int
num_computed_tokens: int
num_gpu_computed_tokens: int
num_cpu_computed_tokens: int
def update(self, other: "ReqMeta"):
self.gpu_block_ids.extend(other.gpu_block_ids)
self.cpu_block_ids.extend(other.cpu_block_ids)
self.num_scheduled_tokens = other.num_scheduled_tokens
self.num_computed_tokens = other.num_computed_tokens
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
@dataclass
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
requests: dict[str, ReqMeta]
finished_req_ids: set[str]
class CPUOffloadingConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
if not vllm_config.cache_config.enable_prefix_caching:
self.connector_scheduler: Optional[
CPUOffloadingConnectorScheduler] = None
self.connector_worker: Optional[
CPUOffloadingConnectorWorker] = None
elif role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = CPUOffloadingConnectorScheduler(
vllm_config)
self.connector_worker = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
if self.connector_worker is not None:
assert isinstance(connector_metadata,
CPUOffloadingConnectorMetadata)
self.connector_worker.bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
assert self.connector_worker is not None
self.connector_worker.clear_connector_metadata()
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
if self.connector_worker is not None:
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
if self.connector_worker is not None:
self.connector_worker.start_load_kv()
def wait_for_layer_load(self, layer_name: str) -> None:
if self.connector_worker is not None:
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
pass
def wait_for_save(self):
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(), None
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
if self.connector_scheduler is not None:
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
return 0, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
if self.connector_scheduler is not None:
return self.connector_scheduler.update_state_after_alloc(request)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
if self.connector_scheduler is not None:
return self.connector_scheduler.build_connector_meta(
scheduler_output)
return KVConnectorMetadata()
def request_finished(
self, request: "Request",
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
if self.connector_scheduler is not None:
self.connector_scheduler.request_finished(request)
return True, None
class CPUOffloadingConnectorScheduler:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorScheduler")
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.use_mla = vllm_config.model_config.use_mla
self.num_gpu_computed_tokens: dict[str, int] = {}
self.num_cpu_computed_tokens: dict[str, int] = {}
self.allocated_req_ids: set[str] = set()
self.finished_req_ids: list[str] = []
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.zmq_rpc_client.call("post_init")
if vllm_config.kv_transfer_config is not None:
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
"swap_in_threshold", 0)
else:
self.swap_in_threshold = 0
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
def get_num_new_matched_tokens(
self, ori_request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
"get_matched_num_and_touch", request)
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
self.num_cpu_computed_tokens[
request.request_id] = num_cpu_computed_tokens
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
return num_cpu_computed_tokens - num_computed_tokens, load_async
else:
return 0, load_async
def update_state_after_alloc(self, request: "Request"):
self.allocated_req_ids.add(request.request_id)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
num_tokens = {}
# process scheduled_new_reqs
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
num_tokens[req_id] = (
req.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
# process scheduled_cached_reqs
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, req_id in enumerate(cached_reqs.req_ids):
num_tokens[req_id] = (
cached_reqs.num_computed_tokens[idx] +
scheduler_output.num_scheduled_tokens[req_id])
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
self.allocated_req_ids -
scheduler_output.num_scheduled_tokens.keys())
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
num_tokens,
unallocated_req_ids)
metadata = CPUOffloadingConnectorMetadata(
requests={},
finished_req_ids=set(self.finished_req_ids),
)
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
gpu_block_ids = req.block_ids[0]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_computed_tokens=req.num_computed_tokens,
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
for idx, req_id in enumerate(cached_reqs.req_ids):
gpu_block_ids = cached_reqs.new_block_ids[idx]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
self.num_gpu_computed_tokens.clear()
self.num_cpu_computed_tokens.clear()
self.allocated_req_ids.clear()
self.finished_req_ids.clear()
return metadata
def request_finished(self, ori_request: "Request"):
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
self.finished_req_ids.append(request.request_id)
# inform metadata server to record request, and free it after finish sending
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
request)
class CPUOffloadingConnectorWorker:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorWorker")
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.pp_rank = get_pp_group().rank_in_group
self.tp_group = get_tp_group()
self.tp_rank = self.tp_group.rank_in_group
self.tp_world_size = self.tp_group.world_size
self.use_mla = vllm_config.model_config.use_mla
self.requests: dict[str, ReqMeta] = {}
self.load_stream = torch.npu.Stream()
self.save_stream = torch.npu.Stream()
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.load_block_mapping: list[tuple[int, int]] = []
self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
self.save_output_queue: queue.Queue[str] = queue.Queue()
self.save_thread = threading.Thread(target=self._save_listener)
self.save_thread.start()
self.done_sending_count: defaultdict[str, int] = defaultdict(int)
# start metadata server to init cpu_kv_cache_manager and handle rpc requests
# all dp shared the same metadata server, only start the process on data_rank 0
if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
config = VllmConfig()
config.cache_config = vllm_config.cache_config
config.parallel_config = vllm_config.parallel_config
config.kv_transfer_config = vllm_config.kv_transfer_config
self.init_metadata_server(config)
self._wait_for_metadata_process_start()
def init_metadata_server(self, vllm_config: VllmConfig):
self.metadata_thread = threading.Thread(
target=MetadataServerProc.run_metadata_server,
args=(vllm_config, ),
)
self.metadata_thread.daemon = True
self.metadata_thread.start()
def _wait_for_metadata_process_start(self):
# TODO: wait for metadata server to start, add a rpc to check if ready
while True:
try:
if self.zmq_rpc_client.call("ready"):
break
except Exception as e:
logger.info(f"wait for metadata server to start, error: {e}")
time.sleep(1)
def bind_connector_metadata(
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
for req_id, req in connector_metadata.requests.items():
if req_id in self.requests:
self.requests[req_id].update(req)
req = self.requests[req_id]
else:
self.requests[req_id] = req
for i in range(req.num_gpu_computed_tokens // self.block_size,
req.num_computed_tokens // self.block_size):
self.load_block_mapping.append(
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
for req_id in connector_metadata.finished_req_ids:
if req_id in self.requests:
self.save_input_queue.put((req_id, self.requests[req_id]))
def clear_connector_metadata(self) -> None:
self.load_block_mapping.clear()
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
self.gpu_kv_caches = kv_caches
model_config = self.vllm_config.model_config
mla_config: Optional[MLAConfig] = None
if model_config.use_mla:
mla_config = MLAConfig(
model_config.hf_text_config.kv_lora_rank,
model_config.hf_text_config.qk_rope_head_dim)
self.cpu_kv_caches = list(
self.zmq_rpc_client.call(
"init_cpu_kv_caches",
self.pp_rank,
self.tp_rank,
get_kv_cache_spec(self.vllm_config),
mla_config,
).values())
def start_load_kv(self) -> None:
self.current_layer = 0
self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
self.load_kv_layer(0)
def wait_for_layer_load(self) -> None:
# TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
self.load_stream.synchronize()
self.current_layer += 1
self.load_kv_layer(self.current_layer)
def load_kv_layer(self, layer: int):
if layer == len(self.gpu_kv_caches):
return
gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
cpu_kv_caches = self.cpu_kv_caches[layer]
with torch.npu.stream(self.load_stream):
for cpu_block_id, gpu_block_id in self.load_block_mapping:
for gpu_layer_part, cpu_layer_part in zip(
gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(
cpu_layer_part[cpu_block_id], non_blocking=True)
def get_finished(self) -> set[str]:
done_sending: set[str] = set()
while True:
try:
id = self.save_output_queue.get_nowait()
except queue.Empty:
break
done_sending.add(id)
for id in done_sending:
del self.requests[id]
if self.tp_world_size == 1:
return done_sending
if self.tp_rank == 0:
for req_id in done_sending:
self.done_sending_count[req_id] += 1
other_ranks_finished_ids: list[str] = []
for i in range(1, self.tp_world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
self.done_sending_count[req_id] += 1
all_done_sending: set[str] = set()
for req_id in list(self.done_sending_count.keys()):
if self.done_sending_count[req_id] == self.tp_world_size:
del self.done_sending_count[req_id]
all_done_sending.add(req_id)
# release cpu_kv_cache after request sending finished
# to avoid rpc blocking, use thread to call rpc asynchronously
sending_finished_thread = threading.Thread(
target=self._sending_finished, args=(all_done_sending, ))
sending_finished_thread.daemon = True
sending_finished_thread.start()
return all_done_sending
else:
self.tp_group.send_object(done_sending, dst=0)
return done_sending
def _sending_finished(self, all_done_sending):
for req_id in all_done_sending:
logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
self.zmq_rpc_client.call("cache_and_free_slots", req_id)
def _save_listener(self):
save_block_mapping = []
while True:
req_id, req = self.save_input_queue.get()
for i in range(
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) //
self.block_size, len(req.cpu_block_ids))):
save_block_mapping.append(
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
with torch.npu.stream(self.save_stream):
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
# non-MLA: kv_layer is list[tensor], typically means [k, v].
if self.use_mla:
start, step = self.tp_rank, self.tp_world_size
else:
start, step = 0, 1
for i in range(start, len(save_block_mapping), step):
gpu_block_id, cpu_block_id = save_block_mapping[i]
for cpu_kv_caches, gpu_kv_caches in zip(
self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(
cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(
gpu_layer_part[gpu_block_id],
non_blocking=True)
self.save_stream.synchronize()
self.save_output_queue.put(req_id)
save_block_mapping.clear()
# Copied from vllm_npu/worker/model_runner_v1.py.
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
forward_ctx = vllm_config.compilation_config.static_forward_context
block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
ascend_config = get_ascend_config()
use_sfa = ascend_config.use_sfa
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
if use_mla and not use_sfa:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hack way to fix deepseek kvcache when
# using DSA. Fix the spec in vLLM is a finnal way.
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec

View File

@@ -0,0 +1,202 @@
import time
from collections import defaultdict
from typing import Optional
from vllm.utils import logger, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
PrefixCachingMetrics)
from vllm.v1.core.single_type_kv_cache_manager import \
get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
class CPUCacheStats:
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
self.enable_prefix_caching = enable_prefix_caching
self.log_stats = log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
self.time_sec = int(time.time())
def log(self):
current_time_sec = int(time.time())
# Log the prefix cache hit rate every 10 seconds.
if current_time_sec - self.time_sec >= 10:
self.time_sec = current_time_sec
logger.info("CPU Prefix cache hit rate: %.1f%%",
self.cpu_prefix_cache_metrics.hit_rate * 100)
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
"""
if not self.log_stats:
return None
stats = self.prefix_cache_stats
self.prefix_cache_stats = PrefixCacheStats()
return stats
def update(self, num_tokens, num_computed_tokens):
# Note the function is called by scheduler
if self.log_stats and self.enable_prefix_caching:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += num_tokens
self.prefix_cache_stats.hits += num_computed_tokens
def set_cache_stats(self, num_tokens, num_computed_tokens):
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.hits = num_computed_tokens
self.prefix_cache_stats.queries = num_tokens
self.prefix_cache_stats.requests = 1
class CPUKVCacheManager:
def __init__(
self,
kv_cache_spec: KVCacheSpec,
num_cpu_blocks: int,
caching_hash_algo: str = "builtin",
use_eagle: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
self.block_size = kv_cache_spec.block_size
self.num_cpu_blocks = num_cpu_blocks
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle
self.block_pool = BlockPool(self.num_cpu_blocks, True,
enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=0,
)
# Record kv block hashes, avoid redundant computation.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
# Record blocks touched in get_matched_num_and_touch().
self.req_to_computed_blocks: defaultdict[
str, list[KVCacheBlock]] = defaultdict(list)
# Record the request that failed to allocate.
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
log_stats=True)
# Record request that will be free after finish sending
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
# When the request requires prompt logprobs, we skip prefix caching.
if (request.sampling_params.prompt_logprobs is not None):
return 0, False
request_id = request.request_id
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request_id]
if not block_hashes:
block_hashes = request.block_hashes
self.req_to_block_hashes[request_id] = block_hashes
max_cache_hit_length = request.num_tokens - 1
computed_blocks = self.single_type_manager.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.single_type_manager.kv_cache_spec,
use_eagle=self.use_eagle,
)
num_computed_tokens = len(computed_blocks[0]) * self.block_size
self.req_to_computed_blocks[request_id] = computed_blocks[0]
# We should touch these blocks in the concurrent scenarios.
self.block_pool.touch(computed_blocks)
# cup prefix cache status set and log
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.log()
return num_computed_tokens, False
def _release_ahead_touch(self, request_id: str):
computed_blocks = self.req_to_computed_blocks[request_id]
if computed_blocks:
self.single_type_manager.block_pool.free_blocks(
reversed(computed_blocks))
self.req_to_computed_blocks.pop(request_id, None)
def allocate_slots(self, req_to_num_tokens: dict[str, int],
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
for request_id in unallocated_req_ids:
self._free_slots(request_id)
req_to_new_blocks = {}
for request_id, num_tokens in req_to_num_tokens.items():
if self.req_failed_to_allocate[request_id]:
continue
new_computed_blocks = self.req_to_computed_blocks[request_id]
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
))
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
self._release_ahead_touch(request_id)
self.req_failed_to_allocate[request_id] = True
continue
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request_id, new_computed_blocks)
# Allocate new blocks but do not cache now.
new_blocks = self.single_type_manager.allocate_new_blocks(
request_id, num_tokens)
self.req_to_num_tokens[request_id] = num_tokens
# No need to release ref_cnt because we use officially.
self.req_to_computed_blocks.pop(request_id, None)
req_to_new_blocks[request_id] = [
block.block_id for block in new_computed_blocks + new_blocks
]
return req_to_new_blocks
def record_request_cache_and_free_slots(self, request: Request):
logger.debug(
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
)
self.req_to_free[request.request_id] = request
def cache_and_free_slots(self, request_id: str):
logger.debug(
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
)
if request_id not in self.req_to_free:
logger.Error(
f"request {request_id} not in req_to_free, maybe bug!")
return
request = self.req_to_free[request_id]
if not self.req_failed_to_allocate[request_id]:
self.single_type_manager.cache_blocks(
request,
self.req_to_num_tokens[request_id],
)
self._free_slots(request_id)
logger.debug(
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
del self.req_to_free[request_id]
def _free_slots(self, request_id: str):
# This function is designed to be reentrant.
self._release_ahead_touch(request_id)
self.single_type_manager.free(request_id)
self.req_to_block_hashes.pop(request_id, None)
self.req_to_computed_blocks.pop(request_id, None)
self.req_failed_to_allocate.pop(request_id, None)
self.req_to_num_tokens.pop(request_id, None)

View File

@@ -0,0 +1,269 @@
import math
import os
import pickle
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.config import KVTransferConfig, VllmConfig
from vllm.utils import get_dtype_size, logger, make_zmq_socket
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_npu.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
CPUKVCacheManager
@dataclass
class MLAConfig:
nope_dim: int
rope_dim: int
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
if vllm_config.kv_transfer_config is not None:
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
elif kv_transfer_config.kv_connector == "MultiConnector":
ktcs = kv_transfer_config.kv_connector_extra_config.get(
"connectors")
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
return None
class MetadataServer:
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
DEFAULT_CPU_SWAP_SPACE_GB = 800
class ZMQRPCClient:
def __init__(self, identity=f"worker-{os.getpid()}"):
logger.info(f"metadata client for worker {identity} started")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.DEALER, # type: ignore
bind=False,
identity=identity.encode(),
linger=0)
def call(self, func_name: str, *args, **kwargs) -> Any:
request = (func_name, args, kwargs)
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(request))
_ = self.socket.recv()
response = pickle.loads(self.socket.recv())
result, error = response
if error:
logger.exception(f"call metadata sever error: {error}")
raise error
if func_name == "init_cpu_kv_caches":
(memory_dict, layer_size, layer_dtype, mla_config) = result
# shared_memory_dict is recorded in self to close
self.shared_memory_dict = memory_dict
result = {}
for key, shm in memory_dict.items():
tensor = torch.frombuffer(
shm.buf, dtype=layer_dtype).reshape(layer_size)
if mla_config is not None:
tensor = tensor.split(
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
result[key] = tensor
return result
def __del__(self):
# will be finalized by outer process
self.socket.close()
self.ctx.term()
if hasattr(self, 'shared_memory_dict'):
for shm in self.shared_memory_dict.values():
shm.close()
def __init__(self, vllm_config: VllmConfig):
self.world_size = vllm_config.parallel_config.world_size
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
kv_transfer_config = get_cpu_offload_connector(vllm_config)
assert kv_transfer_config is not None
available_memory_gb = kv_transfer_config.get_from_extra_config(
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
logger.info(f"cpu swap space: {self.available_memory} bytes")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.ROUTER, # type: ignore
bind=True,
linger=0)
self.functions: dict[str, Callable] = {
"init_cpu_kv_caches": self.init_cpu_kv_caches,
"post_init": self.post_init,
"ready": self.ready,
}
self.shared_memory = {} # type: ignore
self.num_cpu_blocks = -1
@staticmethod
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
try:
existing_shm = SharedMemory(name=name, create=False)
existing_shm.close()
existing_shm.unlink()
except FileNotFoundError:
pass
return SharedMemory(name=name, create=True, size=size)
def ready(self):
return True
def init_cpu_kv_caches(
self,
pp_rank: int,
tp_rank: int,
kv_cache_specs: dict[str, AttentionSpec],
mla_config: MLAConfig,
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
MLAConfig]:
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
# follow the assumption that each layer has the same spec
layer = next(iter(kv_cache_specs.values()))
assert all([
layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values()
])
# mla shares the same kv cache among different tp
if layer.use_mla:
tp_rank = 0
if (pp_rank, tp_rank) in self.shared_memory:
return self.shared_memory[(pp_rank, tp_rank)]
available_memory = self.available_memory
shared_memory_dict = {}
if layer.use_mla:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
else:
available_memory //= self.world_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
for layer_name in kv_cache_specs.keys():
# only this format can share during ZeroMQ+pickle
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
if layer.use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, mla_config)
else:
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, None)
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
self.num_cpu_blocks = num_blocks
self.layer = layer
return self.shared_memory[(pp_rank, tp_rank)]
def post_init(self):
# different processors in data parallel may call multiple times
if hasattr(self, 'cpu_block_manager'):
return
# do shared_memory() at least once
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
assert self.num_cpu_blocks >= 0
self.cpu_block_manager = CPUKVCacheManager(self.layer,
self.num_cpu_blocks)
self.functions.update({
"get_matched_num_and_touch":
self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots":
self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots":
self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots":
self.cpu_block_manager.cache_and_free_slots,
})
def serve_step(self):
client_id = self.socket.recv()
_ = self.socket.recv()
raw_msg = self.socket.recv()
try:
func_name, args, kwargs = pickle.loads(raw_msg)
except Exception as e:
response = (None, Exception(f"Invalid request: {str(e)}"))
else:
if func_name in self.functions:
try:
result = self.functions[func_name](*args, **kwargs)
response = (result, None) # type: ignore
except Exception as e:
logger.exception(f"metadata execute error: {e}")
response = (None, e) # type: ignore
else:
response = (None, NameError(f"Function {func_name} not found"))
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(response))
def shutdown(self):
self.socket.close()
self.ctx.term()
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
"ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
for cached in self.shared_memory.values():
for shm in cached[0].values():
shm.close()
shm.unlink()
class MetadataServerProc:
@staticmethod
def run_metadata_server(vllm_config: VllmConfig):
if (not vllm_config.cache_config.enable_prefix_caching
or get_cpu_offload_connector(vllm_config) is None):
return
shutdown_requested = False
def _signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the worker
# signal.signal(signal.SIGTERM, _signal_handler)
# signal.signal(signal.SIGINT, _signal_handler)
metadata_server: Optional[MetadataServer] = None
try:
metadata_server = MetadataServer(vllm_config)
logger.info("Metadata server started.")
while True:
metadata_server.serve_step()
except SystemExit:
logger.info("Metadata server exiting.")
raise
except Exception as e:
logger.exception(f"Metadata server error: {e}.")
raise e
finally:
if metadata_server is not None:
metadata_server.shutdown()

View File

@@ -0,0 +1,165 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import logger
from vllm_npu.distributed.device_communicators.pyhccl_wrapper import (
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
hcclRedOpTypeEnum, hcclUniqueId)
from vllm_npu.utils import current_stream
class PyHcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyHcclCommunicator to. If None,
it will be bind to f"npu:{local_rank}".
library_path: the path to the HCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.HCCL, (
"PyHcclCommunicator should be attached to a non-HCCL group.")
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
return
try:
self.hccl = HCCLLibrary(library_path)
except Exception:
# disable because of missing HCCL library
# e.g. in a non-NPU environment
self.available = False
self.disabled = True
return
self.available = True
self.disabled = False
logger.info("vLLM is using pyhccl")
if isinstance(device, int):
device = torch.device(f"npu:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
if self.rank == 0:
# get the unique id from HCCL
with torch.npu.device(device):
self.unique_id = self.hccl.hcclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = hcclUniqueId()
if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
# hccl communicator and stream will use this device
# `torch.npu.device` is a context manager that changes the
# current npu device to the specified one
with torch.npu.device(device):
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
self.world_size, self.unique_id, self.rank)
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
stream.synchronize()
del data
def all_reduce(self,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled:
return None
# hccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
hcclDataTypeEnum.from_torch(in_tensor.dtype),
hcclRedOpTypeEnum.from_torch(op), self.comm,
aclrtStream_t(stream.npu_stream))
return out_tensor
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this hccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
if src == self.rank:
buffer = buffer_type(tensor.data_ptr())
else:
buffer = buffer_type(tensor.data_ptr())
self.hccl.hcclBroadcast(buffer, tensor.numel(),
hcclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, aclrtStream_t(stream.npu_stream))

View File

@@ -0,0 +1,253 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import logger
from vllm_npu.utils import find_hccl_library
# export types and functions from hccl to Python ===
# for the original hccl definition, please check
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
hcclResult_t = ctypes.c_int
hcclComm_t = ctypes.c_void_p
class hcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 4108)]
aclrtStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
hcclDataType_t = ctypes.c_int
class hcclDataTypeEnum:
hcclInt8 = 0
hcclInt16 = 1
hcclInt32 = 2
hcclFloat16 = 3
hcclFloat32 = 4
hcclInt64 = 5
hcclUint64 = 6
hcclUint8 = 7
hcclUint16 = 8
hcclUint32 = 9
hcclFloat64 = 10
hcclBfloat16 = 11
hcclInt128 = 12
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.hcclInt8
if dtype == torch.uint8:
return cls.hcclUint8
if dtype == torch.int32:
return cls.hcclInt32
if dtype == torch.int64:
return cls.hcclInt64
if dtype == torch.float16:
return cls.hcclFloat16
if dtype == torch.float32:
return cls.hcclFloat32
if dtype == torch.float64:
return cls.hcclFloat64
if dtype == torch.bfloat16:
return cls.hcclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
hcclRedOp_t = ctypes.c_int
class hcclRedOpTypeEnum:
hcclSum = 0
hcclProd = 1
hcclMax = 2
hcclMin = 3
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.hcclSum
if op == ReduceOp.PRODUCT:
return cls.hcclProd
if op == ReduceOp.MAX:
return cls.hcclMax
if op == ReduceOp.MIN:
return cls.hcclMin
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class HCCLLibrary:
exported_functions = [
# const char* HcclGetErrorString(HcclResult code);
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
Function("HcclGetRootInfo", hcclResult_t,
[ctypes.POINTER(hcclUniqueId)]),
# HcclResult HcclCommInitRootInfo(
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
Function("HcclCommInitRootInfo", hcclResult_t, [
ctypes.c_int,
ctypes.POINTER(hcclUniqueId),
ctypes.c_int,
ctypes.POINTER(hcclComm_t),
]),
# HcclResult HcclAllReduce(
# void *sendBuf, void *recvBuf, uint64_t count,
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
# aclrtStream stream);
Function("HcclAllReduce", hcclResult_t, [
buffer_type,
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
hcclRedOp_t,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclBroadcast(
# void *buf, uint64_t count,
# HcclDataType dataType, uint32_t root,
# HcclComm comm, aclrtStream stream);
Function("HcclBroadcast", hcclResult_t, [
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
ctypes.c_int,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclCommDestroy(HcclComm comm);
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the correspongding directory
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_hccl_library()
try:
if so_file not in HCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
HCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = HCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
logger.error(
"Failed to load HCCL library from %s. "
"It is expected if you are not running on Ascend NPUs."
"Otherwise, the hccl library might not exist, be corrupted "
"or it does not support the current platform %s. "
"If you already have the library, please set the "
"environment variable HCCL_SO_PATH"
" to point to the correct hccl library path.", so_file,
platform.platform())
raise e
if so_file not in HCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in HCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]
def hcclGetErrorString(self, result: hcclResult_t) -> str:
return self._funcs["HcclGetErrorString"](result).decode("utf-8")
def HCCL_CHECK(self, result: hcclResult_t) -> None:
if result != 0:
error_str = self.hcclGetErrorString(result)
raise RuntimeError(f"HCCL error: {error_str}")
def hcclGetUniqueId(self) -> hcclUniqueId:
unique_id = hcclUniqueId()
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
ctypes.byref(unique_id)))
return unique_id
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
rank: int) -> hcclComm_t:
comm = hcclComm_t()
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
return comm
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
# `datatype` actually should be `hcclDataType_t`
# and `op` should be `hcclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
root: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
root, comm, stream))
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
__all__ = [
"HCCLLibrary",
"hcclDataTypeEnum",
"hcclRedOpTypeEnum",
"hcclUniqueId",
"hcclComm_t",
"aclrtStream_t",
"buffer_type",
]

View File

@@ -0,0 +1,994 @@
import contextlib
import copy
import json
import math
import os
import threading
import time
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Tuple
import llm_datadist # type: ignore
import msgspec
import torch
import zmq
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
LLMException, LLMRole)
from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.forward_context import ForwardContext
from vllm.utils import get_ip, logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
import vllm_npu.envs as envs_ascend
from vllm_npu.distributed.utils import get_transfer_timeout_value
from vllm_npu.utils import AscendSocVersion, get_ascend_soc_version
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
torch.float16: llm_datadist.DataType.DT_FLOAT16,
torch.bfloat16: llm_datadist.DataType.DT_BF16,
torch.float: llm_datadist.DataType.DT_FLOAT,
torch.float32: llm_datadist.DataType.DT_FLOAT,
torch.int8: llm_datadist.DataType.DT_INT8,
torch.int64: llm_datadist.DataType.DT_INT64,
torch.int32: llm_datadist.DataType.DT_INT32
}
class LLMDataDistCMgrEvent(Enum):
ReqForMetadata = 0
ReqForFinished = 1
class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
super_pod_id: str
server_id: str
device_id: str
device_ip: str
super_device_id: str
cluster_id: int
@dataclass
class ReqMeta:
local_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
remote_port: str
engine_id: str
remote_tp_size: str
class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
def add_new_req(self, request_id: str, local_block_ids: list[int],
kv_transfer_params: dict[str, Any]):
self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_tp_size=kv_transfer_params["remote_tp_size"],
)
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: Optional[
LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
vllm_config, self.engine_id)
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(
self,
kv_caches: dict[
str, # type: ignore[override]
Tuple[torch.Tensor]]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished(finished_req_ids)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
LLMDataDistCMgrConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager."""
pass
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata, **kwargs) -> None:
"""LLMDataDistCMgrConnector does not save explicitly."""
pass
def wait_for_save(self):
"""LLMDataDistCMgrConnector does not save explicitly."""
pass
class LLMDataDistCMgrConnectorScheduler():
def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
self.local_ip = get_ip()
# Can not retrieve the parallel config since it is not initialized.
self.local_dp_rank = None
self.tp_size = None
if vllm_config.parallel_config.data_parallel_external_lb:
dp_rank_local = vllm_config.parallel_config.data_parallel_rank
else:
dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.port = dp_rank_local * tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}"
)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
assert num_computed_tokens % self.block_size == 0
# Note: We use the full token count as transmit data here.
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
return count, count > 0
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
num_externel_tokens: int):
params = request.kv_transfer_params
logger.debug(
f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}"
)
if params is not None and params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port", "remote_tp_size")):
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
else:
logger.warning("" \
f"Invalid KVTransferParams {params}, This request will be discard")
else:
assert num_externel_tokens == 0
params["do_remote_prefill"] = False
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = LLMDataDistCMgrConnectorMetadata()
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params)
meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_send.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
params = request.kv_transfer_params
logger.debug(
"LLMDataDistCMgrConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
if (params is None or not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None
# note: NIXL transfer the full block only, but I don't see any reason to do that, so here
# we just transfer any data that computed from prefill node
# note: there might be some issue on this, check it if there is any unexpected result
computed_block_ids = block_ids
delay_free_blocks = len(computed_block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
# Prefill request on remote. It will be read from D upon completion
self._reqs_need_send[request.request_id] = time.perf_counter(
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.local_ip,
remote_port=self.port,
remote_tp_size=str(
self.vllm_config.parallel_config.tensor_parallel_size),
)
class LLMDataDistCMgrConnectorWorker():
"""
Implementation of Worker side methods
"""
def __init__(self, vllm_config: VllmConfig):
assert vllm_config.kv_transfer_config is not None
logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
# we assume the local node only contains dp and tp, and tp will not communicate inter-node.
# for any scenario beyond this scope, the functionality of this connector is not guaranteed.
self.local_rank_on_node = get_world_group().rank % (
vllm_config.parallel_config.data_parallel_size_local *
vllm_config.parallel_config.tensor_parallel_size)
self.local_rank = get_world_group().local_rank
if vllm_config.parallel_config.data_parallel_external_lb:
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank
else:
self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_rank = get_tp_group().rank_in_group
self.rank = get_world_group().rank
self.local_ip = get_ip()
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
self.local_agent_metadata: Optional[
LLMDataDistCMgrAgentMetadata] = None
self.vllm_config = vllm_config
self.executor = ThreadPoolExecutor(1)
self.thread_lock = threading.Lock()
self.llm_datadist_role = None
self.llm_datadist_remote_role = None
if self.kv_transfer_config.kv_role == "kv_producer":
self.llm_datadist_role = LLMRole.PROMPT
self.llm_datadist_remote_role = LLMRole.DECODER
elif self.kv_transfer_config.kv_role == "kv_consumer":
self.llm_datadist_role = LLMRole.DECODER
self.llm_datadist_remote_role = LLMRole.PROMPT
else:
raise RuntimeError(
f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}"
)
# linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"}
self.linked_cluster: dict[Any, Any] = {}
self.prefill_device_list: list[tuple[int, int]] = []
self.decode_device_list: list[tuple[int, int]] = []
global_rank_table = self.read_offline_rank_table()
self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
self.llm_datadist = LLMDataDist(self.llm_datadist_role,
self.local_agent_metadata.cluster_id)
self.init_llm_datadist()
self.finished_reqs: set[str] = set()
self.soc_info = get_ascend_soc_version()
# Set hccl deterministic for model execute
os.environ["HCCL_DETERMINISTIC"] = "true"
self.done_receiving_counts: defaultdict[str,
set[int]] = defaultdict(set)
self.reqs_to_send: dict[str, float] = {}
def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
port = envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
url = f"tcp://{envs_ascend.vllm_npu_LLMDD_RPC_IP}:{port}"
msg_encoder = msgspec.msgpack.Encoder()
msg_decoder = msgspec.msgpack.Decoder()
msg_to_send = msg_encoder.encode(self.local_agent_metadata)
logger.debug(f"Start to listen to address: {url}")
logger.debug(
f"The local agent metadata have {len(msg_to_send)} bytes here")
logger.info(
f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers"
)
with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined]
event.set()
while True:
identity, _, msg = sock.recv_multipart()
event_msg, decode_msg = msg_decoder.decode(msg)
event_msg = LLMDataDistCMgrEvent(event_msg)
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
if "cluster_id" in decode_msg:
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
logger.info(
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
)
sock.send_multipart((identity, b"", msg_to_send))
self.add_remote_agent(decode_msg)
else:
logger.warning(
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
)
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
finished_req_id = decode_msg[0]
with self.thread_lock:
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
if finished_req_id in self.reqs_to_send:
self.finished_reqs.add(finished_req_id)
del self.reqs_to_send[finished_req_id]
sock.send_multipart(
(identity, b"", b"receiving decode finished"))
else:
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
)
def init_llm_datadist(self):
assert self.local_agent_metadata is not None
llm_config = LLMConfig()
llm_config.device_id = self.local_rank
llm_config.sync_kv_timeout = get_transfer_timeout_value()
llm_config.enable_switch_role = True
llm_config.enable_cache_manager = True
llm_config.enable_remote_cache_accessible = True
llm_config_options = llm_config.generate_options()
self.llm_datadist.init(llm_config_options)
self.cache_manager = self.llm_datadist.cache_manager
logger.info(
f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
)
def read_offline_rank_table(self):
assert (
envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
with open(rank_table_path, "r", encoding="utf-8") as f:
global_rank_table = json.load(f)
decode_device_list = global_rank_table["decode_device_list"]
for decode_device in decode_device_list:
server_id = decode_device["server_id"]
device_id = decode_device["device_id"]
self.decode_device_list.append((server_id, device_id))
prefill_device_list = global_rank_table["prefill_device_list"]
for prefill_device in prefill_device_list:
server_id = prefill_device["server_id"]
device_id = prefill_device["device_id"]
self.prefill_device_list.append((server_id, device_id))
# global_rank_table = json.dumps(global_rank_table)
return global_rank_table
@staticmethod
def _get_visible_devices() -> Callable[[str], bool]:
"""
Return a test function that check if the given device ID is visible.
i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id.
"""
visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
if not visible_devices:
return lambda device_id: True
visible_device_list = visible_devices.split(",")
return lambda device_id: device_id in visible_device_list
def read_agent_metadata(self, global_rank_table):
device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
devices_type_list = []
agent_metadata = None
if self.llm_datadist_role == LLMRole.PROMPT:
devices_type_list.append("prefill_device_list")
elif self.llm_datadist_role == LLMRole.DECODER:
devices_type_list.append("decode_device_list")
else:
devices_type_list.append("prefill_device_list")
devices_type_list.append("decode_device_list")
for device_type in devices_type_list:
device_list = global_rank_table[device_type]
device_list = [
d for d in device_list if d.get("server_id") == self.local_ip
and device_filter(d.get("device_id", ""))
]
if len(device_list) <= self.tp_rank:
continue
device_info = device_list[self.tp_rank]
super_pod_id_ = device_info.get("super_pod_id", None)
server_id_ = device_info["server_id"]
device_id_ = device_info["device_id"]
device_ip_ = device_info["device_ip"]
super_device_id_ = device_info.get("super_device_id", None)
cluster_id_ = int(device_info["cluster_id"])
agent_metadata = LLMDataDistCMgrAgentMetadata(
super_pod_id=super_pod_id_,
server_id=server_id_,
device_id=device_id_,
device_ip=device_ip_,
super_device_id=super_device_id_,
cluster_id=cluster_id_,
)
assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
return agent_metadata
def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
assert len(first_kv_cache_tuple) > 1
assert self.local_agent_metadata is not None
kv_cache_dtype = first_kv_cache.dtype
self.use_mla: bool = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sparse: bool = len(first_kv_cache_tuple) == 3
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
# MHA case. [2 (k and v), num_blocks, ...]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = math.prod(block_shape)
self.cache_addr: list[int] = []
alignment = 2 * 1024 * 1024
if self.use_mla:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
k_normed = None
k_pe = None
for cache_or_caches in kv_caches.values():
assert len(cache_or_caches) > 1
k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
cache_k_normed_addr_list.append(k_normed.data_ptr())
cache_k_pe_addr_list.append(k_pe.data_ptr())
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)
cache_desc_k_normed = CacheDesc(
len(self.cache_addr[0]), [*k_normed.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_pe = CacheDesc(
len(self.cache_addr[1]), [*k_pe.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=0)
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=1)
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
self.cache_key = (cache_key_k_normed, cache_key_k_pe)
try:
cache_k_normed = self.cache_manager.register_blocks_cache(
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
cache_k_pe = self.cache_manager.register_blocks_cache(
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
self.cache = (cache_k_normed, cache_k_pe)
logger.info("LLMDataDistWorker: End of register Paged Cache.")
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
elif self.use_sparse:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
cache_k_idx_addr_list = []
k_normed = None
k_pe = None
k_idx = None
for cache_or_caches in kv_caches.values():
assert len(cache_or_caches) > 1
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
1], cache_or_caches[2]
cache_k_normed_addr_list.append(k_normed.data_ptr())
cache_k_pe_addr_list.append(k_pe.data_ptr())
cache_k_idx_addr_list.append(k_idx.data_ptr())
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
cache_k_idx_addr_list)
cache_desc_k_normed = CacheDesc(
len(self.cache_addr[0]), [*k_normed.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_pe = CacheDesc(
len(self.cache_addr[1]), [*k_pe.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_idx = CacheDesc(
len(self.cache_addr[2]), [*k_idx.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=0)
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=1)
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=2)
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
cache_desc_k_idx)
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
cache_key_k_idx)
try:
cache_k_normed = self.cache_manager.register_blocks_cache(
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
cache_k_pe = self.cache_manager.register_blocks_cache(
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
cache_k_idx = self.cache_manager.register_blocks_cache(
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
logger.info("LLMDataDistWorker: End of register Paged Cache.")
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
else:
for cache_or_caches in kv_caches.values():
for cache in cache_or_caches:
base_addr = cache.data_ptr()
assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
self.cache_addr.append(base_addr)
# register paged kv cache into the llm_cache manager
self.cache_desc = CacheDesc(
len(self.cache_addr), [*cache.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
self.cache_key = BlocksCacheKey(
cluster_id=int(self.local_agent_metadata.cluster_id))
logger.info(
f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
)
try:
self.cache = self.cache_manager.register_blocks_cache(
self.cache_desc, self.cache_addr, self.cache_key)
logger.info(
"LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
self.ready_event = threading.Event()
self.metadata_agent_listener_t = threading.Thread(
target=self.listen_for_agent_metadata_req,
args=(self.ready_event, ),
daemon=True,
name="metadata_agent_listener")
self.metadata_agent_listener_t.start()
self.ready_event.wait()
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
futures = []
for req_id, meta in metadata.requests.items():
logger.debug(f"Start to transmit {req_id}")
future = self.executor.submit(
self._read_blocks,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_ip=meta.remote_host,
remote_port=int(meta.remote_port),
remote_engine_id=meta.engine_id,
request_id=req_id,
remote_tp_size=meta.remote_tp_size,
)
futures.append(future)
def handle_exception(future):
if future.exception():
logger.error(f"KV transfer task failed: {future.exception()}")
for future in futures:
future.add_done_callback(handle_exception)
self.reqs_to_send.update(metadata.reqs_to_send)
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
assert self.local_agent_metadata is not None
remote_cluster_id = metadata.cluster_id
if remote_cluster_id in self.linked_cluster:
logger.debug(
f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
)
return remote_cluster_id
remote_super_pod_id = metadata.super_pod_id
remote_server_id = metadata.server_id
is_same_server = remote_server_id == self.local_agent_metadata.server_id
is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
if self.llm_datadist_role == LLMRole.PROMPT:
prefill_metadata = self.local_agent_metadata
decode_metadata = metadata
else:
prefill_metadata = metadata
decode_metadata = self.local_agent_metadata
comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
cluster_rank_info = {
prefill_metadata.cluster_id: 0,
decode_metadata.cluster_id: 1
}
rank_table = {}
rank_table["version"] = "1.2"
rank_table["server_count"] = "1" if is_same_server else "2"
rank_table["status"] = "completed"
# generate server_list for rank table
rank_table["server_list"] = [] # type: ignore[assignment]
decode_server_device_info = None
prefill_server_device_info = {
"device": [{
k: v
for k, v in [(
"device_id", prefill_metadata.device_id
), ("device_ip", prefill_metadata.device_ip
), ("super_device_id",
prefill_metadata.super_device_id), ("rank_id", "0")]
if v is not None
}],
"server_id":
prefill_metadata.server_id
}
if is_same_server:
prefill_server_device_info["device"].append( # type: ignore[attr-defined]
{
k: v
for k, v in [(
"device_id", decode_metadata.device_id
), ("device_ip", decode_metadata.device_ip
), ("super_device_id",
decode_metadata.super_device_id), ("rank_id", "1")]
if v is not None
})
else:
decode_server_device_info = {
"device": [{
k: v
for k, v in [(
"device_id", decode_metadata.device_id
), ("device_ip", decode_metadata.device_ip
), ("super_device_id",
decode_metadata.super_device_id), ("rank_id", "1")]
if v is not None
}],
"server_id":
decode_metadata.server_id
}
rank_table["server_list"].append( # type: ignore[attr-defined]
prefill_server_device_info)
if decode_server_device_info is not None:
rank_table["server_list"].append( # type: ignore[attr-defined]
decode_server_device_info)
if self.soc_info == AscendSocVersion.A3:
# generate super_pod_list for rank table
super_pod_list = []
prefill_super_pod_info = {
"super_pod_id": prefill_metadata.super_pod_id,
"server_list": [{
"server_id": prefill_metadata.server_id
}],
}
if is_same_pod and not is_same_server:
prefill_super_pod_info[
"server_list"].append( # type: ignore[attr-defined]
{"server_id": decode_metadata.server_id})
super_pod_list.append(prefill_super_pod_info)
if not is_same_pod:
decode_super_pod_id = {
"super_pod_id": decode_metadata.super_pod_id,
"server_list": [{
"server_id": decode_metadata.server_id
}],
}
super_pod_list.append(decode_super_pod_id)
rank_table[
"super_pod_list"] = super_pod_list # type: ignore[assignment]
logger.info(
f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
)
logger.info(f"rank table \n{rank_table}")
logger.info(f"comm name: {comm_name}")
logger.info(f"cluster rank info: {cluster_rank_info}")
comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
json.dumps(rank_table))
while True:
ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
if ret == llm_datadist.RegisterMemStatus.OK:
logger.info(
f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
)
break
elif ret == llm_datadist.RegisterMemStatus.FAILED:
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
)
time.sleep(1)
logger.info("Checking query_register_mem_status again")
self.linked_cluster.update({remote_cluster_id: comm_id})
logger.info(f"cached linked cluster: {self.linked_cluster}")
logger.info(
f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
)
return remote_cluster_id
def remove_remote_agent(self, cluster_id: int):
if cluster_id not in self.linked_cluster:
logger.warning(
f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list"
)
comm_id = self.linked_cluster[cluster_id]
try:
self.llm_datadist.unlink(comm_id)
self.linked_cluster.pop(cluster_id)
except LLMException:
logger.error(
f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment"
)
logger.info(
f"Successfully remove remote client with cluster id {cluster_id} !"
)
def connect_to_remote_agent(self, host: str, port: int) -> int:
url = f"tcp://{host}:{port}"
logger.debug(f"Querying metadata from url: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode(
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
logger.info("Try request remote metadata from socket......")
sock.send(msg_send)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder()
metadata = decoder.decode(metadata_bytes)
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
logger.info(f"recving metadata: {metadata}")
cluster_id = self.add_remote_agent(metadata)
return cluster_id
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
for port in ports:
url = f"tcp://{host}:{port}"
logger.debug(f"Sending finished to remote: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode(
[LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
try:
sock.send(msg_send)
logger.debug(
f"Request id {request_id} finished message send to remote {url}"
)
_ = sock.recv()
except Exception as e:
logger.error(
f"Failed to send reqest_id {request_id} to prefill: {e}"
)
def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_ip: str,
remote_port: int,
remote_engine_id: str,
request_id: str,
remote_tp_size: str,
):
# if remote_ip not in self.linked_cluster:
tp_offset = self.tp_rank % int(remote_tp_size)
remote_cluster_id = self.connect_to_remote_agent(
remote_ip, remote_port + tp_offset)
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
return
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
logger.info(f"remote cluster id is: {remote_cluster_id}")
if self.use_mla:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sparse:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
remote_cache_key_k_idx = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=2)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_idx,
self.cache[2], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key,
self.cache, # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
remote_ports = list(
range(remote_port + self.tp_rank,
remote_port + int(remote_tp_size), self.tp_size))
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requuests."""
now = time.perf_counter()
with self.thread_lock:
while self.reqs_to_send:
req_id, expires = next(iter(self.reqs_to_send.items()))
if now < expires:
break
logger.warning(
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
"If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
)
if req_id in self.reqs_to_send:
self.finished_reqs.add(req_id)
del self.reqs_to_send[req_id]
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
self.finished_reqs.clear()
if self.llm_datadist_role == LLMRole.PROMPT:
return req_ids_to_ret, None
else:
return None, req_ids_to_ret
# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
ctx: Optional[zmq.Context] = None # type: ignore[name-defined]
try:
ctx = zmq.Context() # type: ignore[attr-defined]
if socket_type == zmq.ROUTER: # type: ignore[attr-defined]
socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined]
socket.bind(addr)
elif socket_type == zmq.REQ: # type: ignore[attr-defined]
socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined]
socket.connect(addr)
else:
raise ValueError(f"Unexpected socket type: {socket_type}")
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)

View File

@@ -0,0 +1,449 @@
import array
import hashlib
import json
import os
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import cdiv, logger
from vllm.v1.core.sched.output import NewRequestData
@dataclass
class MooncakeEngineMetadata:
"""name of the LLM model"""
model_name: str
""" world size when running under a distributed setting """
world_size: int
""" worker id when running under a distributed setting """
worker_id: int
""" the format of kv tensors """
kv_dtype: torch.dtype
""" the shape of kv tensors """
""" (num_layer, 2, metadata.block_size, num_kv_head, head_size) """
kv_shape: tuple[int, int, int, int, int]
block_size: int = 128
""" whether use MLA"""
use_mla: bool = False
@dataclass(order=True)
class MooncakeEngineKey:
model_name: str
world_size: int
worker_id: int
chunk_hash: str
def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
))
def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}")
def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
"""Split the key into multiple keys for each layer"""
keys = []
for layer_id in range(num_layers):
keys.append(
LayerMooncakeEngineKey(
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
layer_id,
))
return keys
def to_dict(self):
# Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
return {
"__type__": "CacheEngineKey",
"model_name": self.model_name,
"world_size": self.world_size,
"worker_id": self.worker_id,
"chunk_hash": self.chunk_hash,
}
@staticmethod
def from_dict(d):
return MooncakeEngineKey(
model_name=d["model_name"],
world_size=d["world_size"],
worker_id=d["worker_id"],
chunk_hash=d["chunk_hash"],
)
@dataclass(order=True)
class LayerMooncakeEngineKey(MooncakeEngineKey):
"""A key for the layer cache engine"""
layer_id: int
def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
self.layer_id,
))
def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
class ChunkedTokenDatabase():
def __init__(
self,
metadata: MooncakeEngineMetadata,
):
self.metadata = metadata
def _make_key_by_hash(self,
chunk_hash: str,
layer_id: Optional[int] = None):
assert self.metadata is not None
return MooncakeEngineKey(
self.metadata.model_name,
self.metadata.world_size,
self.metadata.worker_id,
chunk_hash,
)
def _hash(
self,
tokens: Union[torch.Tensor, List[int]],
prefix_hash: str,
) -> str:
# TODO: change it to a more efficient hash function
if isinstance(tokens, torch.Tensor):
tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
elif isinstance(tokens, list):
tokens_bytes = array.array("I", tokens).tobytes()
return hashlib.sha256(prefix_hash.encode("ascii") +
tokens_bytes).hexdigest()
def _chunk_tokens(
self,
tokens: Union[torch.Tensor, List[int]],
) -> Iterable[Union[torch.Tensor, List[int]]]:
"""
Chunk the tokens into chunks of size self.metadata.block_size.
:param tokens: the input tokens, with shape [seq_len]
device: the target device after chunking
:return: a generator of chunks of tokens, each with
shape [metadata.block_size]
"""
for i in range(0, len(tokens), self.metadata.block_size):
yield tokens[i:i + self.metadata.block_size]
def _prefix_hash(
self,
token_chunks: Iterable[Union[torch.Tensor, List[int]]],
) -> Iterable[str]:
prefix_hash = ''
for token_chunk in token_chunks:
prefix_hash = self._hash(token_chunk, prefix_hash)
yield prefix_hash
def process_tokens(
self,
tokens: Union[torch.Tensor, List[int]],
mask: Optional[torch.Tensor] = None,
) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
"""Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched,
and the Falses will ALWAYS be at the PREFIX of the tensor.
:param bool make_key: Whether to make the cache engine key or not.
If False, the hash value will be returned instead.
:returns: A iterable of tuples with three elements. The first element
is the start index of the tokens for the key. The second element
is the end index of the tokens for the key. The third element is
the cache engine key (or hash) for the tokens.
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
"""
if mask is not None:
num_falses = mask.numel() - mask.long().sum().item()
else:
num_falses = 0
if num_falses % self.metadata.block_size != 0:
raise ValueError(
"The number of Falses in the mask is not a multiple of the chunk size."
)
total_len = len(tokens)
token_chunks = self._chunk_tokens(tokens)
prefix_hashes = self._prefix_hash(token_chunks)
start_idx = 0
for chunk_id, hash_val in enumerate(prefix_hashes):
start_idx = chunk_id * self.metadata.block_size
end_idx = min(start_idx + self.metadata.block_size, total_len)
if start_idx < num_falses:
continue
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
@dataclass
class LoadSpec:
# Number of tokens cached in vLLM
vllm_cached_tokens: int
# Number of tokens that are cached in mooncake
mooncake_cached_tokens: int
# Whether the scheduler allow us to load the tokens
can_load: bool
@dataclass
class SaveSpec:
# Skip already saved tokens
skip_leading_tokens: int
# Whether the scheduler allow us to save the tokens
can_save: bool
@dataclass
class RequestTracker:
# Request id
req_id: str
# The token ids that has been scheduled so far
token_ids: list[int]
# The block ids that has been allocated so far
# NOTE: allocated blocks could be more than the number of tokens
# FIXME: need to check whether the block ids will be changed after
# preemption
allocated_block_ids: list[int]
# The number of tokens that has been savd
num_saved_tokens: int = 0
@staticmethod
def from_new_request(
new_request: "NewRequestData",
num_tokens_to_compute: int,
) -> "RequestTracker":
"""Create the request tracker from a new request.
Args:
new_request (NewRequestData): the new request data.
num_tokens_to_compute (int): the number of tokens that will
be 'computed', including the `num_computed_tokens` (vLLM's
local cache hit) and new tokens that will be scheduled.
"""
# vLLM 0.9.0 update: request.block_ids changed from list[int] to
# list[list[int]]
# Need to check the type of request.block_ids
unfolded_block_ids = []
if not isinstance(new_request.block_ids[0], list):
unfolded_block_ids = new_request.block_ids.copy()
else:
unfolded_block_ids = new_request.block_ids[0].copy()
return RequestTracker(
req_id=new_request.req_id,
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
copy(),
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_ids.extend(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
new_block_ids = new_block_ids[0]
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(
f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
@dataclass
class ReqMeta:
# Request id
req_id: str
# Request tokens
token_ids: torch.Tensor
block_ids: list[int]
# # Slot mapping if exchange for block_id
# slot_mapping: torch.Tensor
# Skip save or not
save_spec: Optional[SaveSpec] = None
# load_spec
load_spec: Optional[LoadSpec] = None
is_last_chunk: Optional[bool] = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
block_size: int,
load_spec: Optional[LoadSpec] = None,
skip_save: Optional[bool] = False,
is_last_chunk: Optional[bool] = None,
discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks.
Returns:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
input_token_ids = tracker.token_ids
input_token_len = len(input_token_ids)
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
skip_leading_tokens = tracker.num_saved_tokens
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
block_size if discard_partial_chunks else 0)
# Calculate number of tokens to save based on discard_partial_chunks
# setting
num_tokens_to_save = ((input_token_len // block_size * block_size)
if discard_partial_chunks else input_token_len)
skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None:
return None
# If we need to save, update the number of saved tokens
if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save
save_spec = SaveSpec(skip_leading_tokens, not skip_save)
# Calculate the token ids and slot mappings for load and save
# OPTIMIZATION: pre-allocate the buffer for token ids and block ids
token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]
# # For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load:
logger.debug(
"Scheduled to load %d tokens for request %s",
load_spec.mooncake_cached_tokens,
tracker.req_id,
)
else:
# Do not load if not in `can_load` state
load_spec = None
logger.debug(
f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}"
)
return ReqMeta(
req_id=tracker.req_id,
token_ids=token_ids,
block_ids=tracker.allocated_block_ids,
save_spec=save_spec,
load_spec=load_spec,
is_last_chunk=is_last_chunk,
)
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
def add_request(self, req_meta: ReqMeta) -> None:
"""Add a request to the metadata.
Args:
req_meta (ReqMeta): the request metadata.
"""
self.requests.append(req_meta)
@dataclass
class LasyerMultiBlockReqMeta:
req_id: str
keys: List[LayerMooncakeEngineKey]
starts: List[int]
ends: list[int]
block_ids: list[int]
layer_id: int
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: int
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
use_ascend_direct: bool
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
with open(file_path) as file:
config = json.load(file)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=config.get("global_segment_size", 3355443200),
local_buffer_size=config.get("local_buffer_size", 1073741824),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
use_ascend_direct=config.get("use_ascend_direct", False))
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path)

View File

@@ -0,0 +1,293 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
import torch
from vllm.utils import logger
from vllm_npu.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
class KVTransferThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.tp_rank = tp_rank
self.tp_size = tp_size
self.m_store = m_store
self.ready_event = ready_event
self.kv_caches_base_addr = local_kv_caches_base_addr
self.block_len = block_len
self.token_database = token_database
self.block_size = block_size
self.done_task_lock = threading.Lock()
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set()
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
layer_id: int):
block_id = block_ids[start // self.block_size]
if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v]
else:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
return addr_list, size_list
def add_request(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
is_last_chunk: Optional[bool] = None,
current_event: Optional[torch.npu.Event] = None,
) -> torch.Tensor:
req = ({
"req_id": req_id,
"tokens": tokens,
"block_ids": block_ids,
"mask": mask,
"is_last_chunk": is_last_chunk,
"current_event": current_event,
})
self.request_queue.put(req)
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
self.finished_requests.clear()
return finished_requests
def set_finished_request(self, req_id):
with self.done_task_lock:
self.finished_requests.add(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
self.ready_event.set()
while True:
try:
request_data = self.request_queue.get()
if request_data is None:
logger.warning("Received a None request!")
self.request_queue.task_done()
continue
self._handle_request(request_data)
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: dict[str, Any]):
pass
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheSendingThread")
def _handle_request(self, req_meta: dict[str, Any]):
tokens = req_meta["tokens"]
mask = req_meta["mask"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
current_event = req_meta["current_event"]
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
if key_list:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
if current_event is not None:
current_event.synchronize()
self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
if current_event is not None:
current_event.synchronize()
self.m_store.put(key, addr, size)
if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: dict[str, Any]):
tokens = req_meta["tokens"]
mask = req_meta["mask"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.get(key, addr, size)
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event,
num_layers: int):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreLayerSendingThread")
self.final_layer_id = num_layers - 1
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
for index, key in enumerate(req_meta.keys):
addr, size = self.prepare_value_layer(req_meta.starts[index],
req_meta.ends[index],
req_meta.block_ids,
req_meta.layer_id)
self.m_store.put(key, addr, size)
if req_meta.layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event,
get_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreLayerRecvingThread")
self.get_event = get_event
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
for index, key in enumerate(req_meta.keys):
addr, size = self.prepare_value_layer(req_meta.starts[index],
req_meta.ends[index],
req_meta.block_ids,
req_meta.layer_id)
self.m_store.get(key, addr, size)
self.request_queue.task_done()
self.get_event.set()

View File

@@ -0,0 +1,639 @@
# Standard
import math
import threading
import time
from typing import Generator, List, Optional, Union
# Third Party
import torch
from vllm.config import VllmConfig
from vllm.utils import get_kv_cache_torch_dtype, logger
from vllm_npu.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
MooncakeEngineMetadata)
from vllm_npu.distributed.mooncake.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_npu.distributed.mooncake.mooncake_store import Mooncakestore
class MooncakeEngine:
#The main class for the cache engine.
def __init__(
self,
vllm_config: VllmConfig,
use_layerwize: bool,
):
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.use_mla = False
if (hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla):
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = parallel_config.rank
self.tp_size = parallel_config.tensor_parallel_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"register_buffer", False)
self.block_size = vllm_config.cache_config.block_size
self.current_layer = 0
# self.use_mla = first_kv_cache_tuple[0].size(
# -1) != first_kv_cache_tuple[1].size(-1)
self.num_layers = model_config.get_num_layers(parallel_config)
self.block_size = vllm_config.cache_config.block_size
num_kv_head = model_config.get_num_kv_heads(parallel_config)
head_size = model_config.get_head_size()
kv_dtype = get_kv_cache_torch_dtype(
vllm_config.cache_config.cache_dtype, model_config.dtype)
self.hidden_dim_size = num_kv_head * head_size
if self.use_mla:
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
else:
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
head_size)
self.metadata = MooncakeEngineMetadata(
model_config.model,
parallel_config.world_size,
parallel_config.rank,
kv_dtype,
kv_shape,
self.block_size,
self.use_mla,
)
self.token_database = ChunkedTokenDatabase(self.metadata)
self.m_store = Mooncakestore(parallel_config)
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
# TODO(tms): Find a more robust way to detect and handle MLA
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
self.kv_caches = kv_caches
self.kv_caches_base_addr = []
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
if self.register_buffer:
region_len = self.num_blocks * self.block_len[i % 2]
self._register(base_addr, region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
if self.register_buffer:
region_len = self.num_blocks * self.block_len[0]
self._register(base_addr, region_len)
if self.use_layerwise:
self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event_sending,
self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database, self.block_len,
self.block_size, ready_event, self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()
def _register(self, ptr, length):
logger.debug(
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
try:
self.m_store.register_buffer(ptr, length)
except Exception as e:
raise RuntimeError(
f"Mooncake memory registration failed. Error is: {e}")
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
self.current_layer = 0
self.layerwise_retrievers = []
for request in metadata.requests:
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load: #load =0
continue
tokens = request.token_ids
req_id = request.req_id
if (load_spec.mooncake_cached_tokens % self.block_size
!= 0) and (load_spec.mooncake_cached_tokens
== tokens.shape[0] - 1):
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
else:
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
masked_token_count = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
token_mask = torch.ones_like(tokens, dtype=torch.bool)
token_mask[:masked_token_count] = False
if self.use_layerwise:
layerwise_retriever = self.retrieve_layer(
req_id,
tokens,
request.block_ids,
token_mask,
)
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
tokens,
request.block_ids,
token_mask,
)
else:
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, block_id = self.prepare_value(
start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list,
blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, _ = self.prepare_value(
start, end, request.block_ids)
self.m_store.get(key, addr, size)
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def wait_for_layer_load(self) -> None:
"""MooncakeConnector does not do layerwise saving."""
for layerwise_retriever in self.layerwise_retrievers:
ret_token_mask = next(layerwise_retriever)
if self.current_layer == self.num_layers - 1:
assert ret_token_mask is not None
num_retrieved_tokens = ret_token_mask.sum().item()
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self,
connector_metadata: MooncakeConnectorMetadata) -> None:
"""MooncakeConnector does not save explicitly."""
if self.current_layer == 0:
self.layerwise_storers = []
current_event = None
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
token_ids = request.token_ids
req_id = request.req_id
assert isinstance(token_ids, torch.Tensor)
assert token_ids.is_cpu
# TODO: whether need to remov saveThread
# no lookup, skipmask
skip_leading_tokens = max(
self.lookup(token_ids, self.use_layerwise),
save_spec.skip_leading_tokens,
)
if skip_leading_tokens == len(token_ids):
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
skip_leading_tokens = (skip_leading_tokens // self.block_size *
self.block_size)
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
store_mask[:skip_leading_tokens] = False
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
len(token_ids) - skip_leading_tokens,
len(token_ids),
skip_leading_tokens,
request.req_id,
)
layerwise_storer = self.store_layer(
req_id,
token_ids,
mask=store_mask,
block_ids=request.block_ids,
)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
next(layerwise_storer)
except Exception:
raise
self.current_layer = self.current_layer + 1
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
"""MooncakeConnector does not save explicitly."""
current_event = None
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
continue
token_ids = request.token_ids
req_id = request.req_id
assert isinstance(token_ids, torch.Tensor)
assert token_ids.is_cpu
skip_leading_tokens = max(
self.lookup(token_ids, self.use_layerwise),
save_spec.skip_leading_tokens,
)
if skip_leading_tokens == len(token_ids):
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
skip_leading_tokens = (skip_leading_tokens // self.block_size *
self.block_size)
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
store_mask[:skip_leading_tokens] = False
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
len(token_ids) - skip_leading_tokens,
len(token_ids),
skip_leading_tokens,
request.req_id,
)
self.kv_send_thread.add_request( # type: ignore[union-attr]
req_id,
token_ids,
request.block_ids,
store_mask,
request.is_last_chunk,
current_event,
)
def retrieve_layer(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched.
:param **kwargs: The additional arguments for the KV transfer which
will be passed into the npu_transfer.
return: A generator that yields Optional[torch.Tensor]. The tensor will
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
"""
if mask is not None:
num_required_tokens = torch.sum(mask).item()
else:
num_required_tokens = len(tokens)
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
starts = []
ends = []
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
tokens, mask):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer)
ret_mask[start:end] = True
if keys:
# Transpose the keys into layer major format
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag:
is_finish = self.get_event.wait(timeout=3) #try---cache
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False
yield None
else:
# If no cache are found, we still need to yield to avoid
# `StopIteration`
for layer_id in range(self.num_layers):
yield None
retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {len(tokens)} tokens")
yield ret_mask
def store_layer(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched.
:param **kwargs: The additional arguments for the storage backend which
will be passed into the gpu_connector.
return: A generator that yields None. In the first iteration, the
generator allocates the memory objects for all layers and moves
the KV cache of the first layer from GPU to CPU. In the next
iterations, it moves the KV cache of layer i from GPU to the memory
objects (on CPU) and puts the memory objects of layer i-1 to the
storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends.
"""
if mask is not None:
num_stored_tokens = torch.sum(mask).item()
else:
num_stored_tokens = len(tokens)
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer) #[block_num,layer_num]
if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield
logger.debug(
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
done_recving = (
self.kv_recv_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.load_async else set())
logger.debug(
"Number of completed KV cache send requests: %d, receive "
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
self.tp_rank)
return done_sending, done_recving
def wait_layer_transfer_finish(self):
time.sleep(10)
pass
def lookup(
self,
tokens: Union[torch.Tensor, List[int]],
use_layerwise: bool,
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
end = 0
keys = []
try:
if use_layerwise:
for start, end, key in self.token_database.process_tokens(
tokens):
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
# batch is_exists
ress = self.m_store.batch_exists(keys)
res = 1
for value in ress:
if value != 1:
res = 0
break
if res == 1:
continue
else:
return start
else:
starts = []
for start, end, key in self.token_database.process_tokens(
tokens):
keys.append(key.to_string())
starts.append(start)
res = self.m_store.batch_exists(
keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return starts[index]
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def lookup_scheduler(
self,
tokens: Union[torch.Tensor, List[int]],
use_layerwise: bool,
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
end = 0
keys = []
try:
if use_layerwise:
for start, end, key in self.token_database.process_tokens(
tokens):
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
# batch is_exists
ress = self.m_store.batch_exists(keys)
res = 1
for value in ress:
if value != 1:
res = 0
break
if res == 1:
continue
else:
return start
else:
starts = []
for start, end, key in self.token_database.process_tokens(
tokens):
keys.append(key.to_string())
starts.append(start)
multi_tp_keys = keys[:]
for i in range(1, self.tp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@0", f"@{i}", 1)
multi_tp_keys.append(new_str)
res = self.m_store.batch_exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
multi_tp_values = [
res[i * num_block:(i + 1) *
num_block] # type: ignore[index]
for i in range(self.tp_size)
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
return starts[index]
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def find_min_first_non_one_index(self, arr):
try:
return min(idx for row in arr for idx, val in enumerate(row)
if val != 1)
except ValueError:
return -1
def close(self) -> None:
"""Close the cache engine and free all the resources"""
self.m_store.close()

View File

@@ -0,0 +1,126 @@
# Standard
import os
# Third Party
from mooncake.store import ReplicateConfig # type: ignore
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.utils import get_ip, logger
from vllm_npu.distributed.mooncake.config_data import MooncakeEngineKey
from vllm_npu.distributed.mooncake.transfer_engine import get_global_te
from .config_data import MooncakeStoreConfig
METADATA_BYTES_LEN = 24
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
class Mooncakestore():
def __init__(self, parallel_config: ParallelConfig):
try:
from mooncake.store import MooncakeDistributedStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
tp_rank = get_tensor_model_parallel_rank()
tp_size = parallel_config.tensor_parallel_size
dp_rank = parallel_config.data_parallel_rank_local
all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
if not all_device_ids:
device_ids_list = list(
range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
else:
device_ids_list = list(map(int, all_device_ids.split(',')))
assert len(device_ids_list) > tp_rank
device_id = device_ids_list[tp_rank]
self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore()
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
":npu_" + str(device_id)
ret = self.store.setup(local_hostname, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address)
else:
local_hostname = get_ip()
transfer_engine = get_global_te(local_hostname, device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
raise RuntimeError(msg)
def exists(self, key: MooncakeEngineKey) -> bool:
return self.store.is_exist(key.to_string()) == 1
def batch_exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def register_buffer(self, ptr, length):
return self.store.register_buffer(ptr, length)
def get_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
res = self.store.batch_get_into_multi_buffers(
keys, addrs, sizes, True)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {keys}. {e}")
def put_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(
keys, addrs, sizes, config)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}")
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
expect_res = sum(size)
key_str = key.to_string()
try:
res = self.store.batch_get_into_ascend(key_str, addr, size)
if res[0] != expect_res:
logger.error(f"Failed to get key: [{key_str}] .")
except Exception:
logger.error(f"Failed to get key: [{key_str}] .")
return res
def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
key_str = key.to_string()
try:
ret = self.store.batch_put_from_ascend(key_str, addr, size)
if ret[0] != 0:
logger.error(f"Failed to put key {key_str}.")
except Exception:
logger.error(f"Failed to put key {key_str}.")
return ret
def close(self):
self.store.close()
logger.info("Closed the mooncake store connection")

View File

@@ -0,0 +1,494 @@
import threading
from typing import Any, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext
from vllm.utils import logger, make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm_npu.distributed.mooncake.config_data import (
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
from vllm_npu.distributed.mooncake.mooncake_engine import MooncakeEngine
class MooncakeConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
self.kv_caches: dict[str, torch.Tensor] = {}
self._block_size = vllm_config.cache_config.block_size
self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
vllm_config, self.use_layerwise)
else:
self.connector_worker = MooncakeEngine(
vllm_config,
self.use_layerwise,
)
assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0 and self.kv_role != "kv_consumer":
self.lookup_server = MooncakeLookupServer(
self.connector_worker, vllm_config, self.use_layerwise)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._get_connector_metadata(),
MooncakeConnectorMetadata)
self.connector_worker.start_load_kv(self._get_connector_metadata())
def wait_for_layer_load(self, layer_name: str) -> None:
"""MooncakeStoreConnector does not do layerwise saving."""
if not self.use_layerwise:
return
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""MooncakeStoreConnector does not save explicitly."""
if not self.use_layerwise:
return
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
"""MooncakeStoreConnector does not save explicitly."""
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
if self.use_layerwise:
self.connector_worker.wait_layer_transfer_finish()
return
self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished()
sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids:
sended_and_finished.add(item)
self.sended_but_unfinished_reqs.remove(item)
for item in done_sending:
if item in meta.unfinished_request_ids:
self.sended_but_unfinished_reqs.add(item)
else:
sended_and_finished.add(item)
return sended_and_finished, done_recving
def get_zmq_rpc_path_mooncake(
vllm_config: Optional["VllmConfig"] = None, ) -> str:
base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured
rpc_port = 0
if vllm_config is not None:
rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
"mooncake_rpc_port", 0)
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
class MooncakeStoreConnectorV1Scheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.use_layerwise = use_layerwise
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.client = MooncakeLookupClient(
vllm_config) if self.kv_role != "kv_consumer" else None
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
# request_id -> (vllm cached tokes, mooncake cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self._block_size = vllm_config.cache_config.block_size
# request_id -> full_token_ids
self._request_trackers: dict[str, RequestTracker] = {}
# Whether to discard partial chunks
self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True))
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
self._unfinished_request_ids: set[str] = set()
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Check for external KV cache hit.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
return 0, False
if self._discard_partial_chunks:
token_block_end = len(request.prompt_token_ids
) // self._block_size * self._block_size
token_ids = torch.tensor(
request.prompt_token_ids[:token_block_end])
else:
token_ids = torch.tensor(request.prompt_token_ids)
num_external_hit_tokens = self.client.lookup( # type: ignore[union-attr]
token_ids)
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
need_to_allocate = num_external_hit_tokens - num_computed_tokens
logger.info(
"Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
request.request_id,
request.num_tokens,
num_external_hit_tokens,
need_to_allocate,
)
if need_to_allocate <= 0:
return 0, False
self.load_specs[request.request_id] = LoadSpec(
vllm_cached_tokens=num_computed_tokens,
mooncake_cached_tokens=num_external_hit_tokens,
can_load=False,
)
return need_to_allocate, self.load_async
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after temporary buffer alloc.
For SharedStorageConnector, update _request_needs_load
if the CacheManager this allocated blocks for us.
"""
local_block_ids = []
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0]
self._unfinished_requests[request.request_id] = (request,
local_block_ids)
self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs:
# No KV tokens from external KV cache, return
return
if num_external_tokens == 0:
# No need to load anything
self.load_specs[request.request_id].can_load = False
return
assert (
num_external_tokens > 0 and num_external_tokens
== self.load_specs[request.request_id].mooncake_cached_tokens -
self.load_specs[request.request_id].vllm_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}")
self.load_specs[request.request_id].can_load = True
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output
except the `kv_connector_metadata` field.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = self.kv_role == "kv_consumer"
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)
meta = MooncakeConnectorMetadata(self._unfinished_request_ids)
for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = (
request.num_computed_tokens +
scheduler_output.num_scheduled_tokens[request.req_id])
request_tracker = RequestTracker.from_new_request(
request, num_tokens_to_compute)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
cached_reqs = scheduler_output.scheduled_cached_reqs
if isinstance(cached_reqs, list) and not force_skip_save:
for i, req in enumerate(cached_reqs):
request_tracker = self._request_trackers[req.req_id]
request_tracker.update(req.new_token_ids, req.new_block_ids)
last_chunk_tokens_num = ((len(req.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(req.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
elif not force_skip_save:
for i, req_id in enumerate(cached_reqs.req_ids):
request_tracker = self._request_trackers[req_id]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_tuple = self._unfinished_requests.get(req_id)
if req_tuple:
request = req_tuple[0]
num_current_tokens = len(request_tracker.token_ids)
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
f"but it is scheduled to be cached")
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
request_tracker.update(new_token_ids, new_block_ids)
# decode not save
if len(request_tracker.token_ids) > len(
request.prompt_token_ids):
continue
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
request_ids = [
req.req_id for req in scheduler_output.scheduled_new_reqs
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.mooncake_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
token_ids=request.prompt_token_ids[:num_tokens_to_compute].
copy(),
allocated_block_ids=block_ids,
num_saved_tokens=0,
)
self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=None,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if self.kv_role == "kv_consumer":
return False, None
tracker = self._request_trackers.get(request.request_id)
if tracker is not None and tracker.num_saved_tokens <= 0:
return False, None
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(block_ids), request.request_id)
return delay_free_blocks, None
class MooncakeLookupClient:
def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REQ, # type: ignore[attr-defined]
bind=False,
)
def lookup(self, token_ids: torch.Tensor) -> int:
request = self.encoder.encode(token_ids)
self.socket.send_multipart(request, copy=False)
resp = self.socket.recv()
result = int.from_bytes(resp, "big")
return result
def close(self):
self.socket.close(linger=0)
class MooncakeLookupServer:
def __init__(
self,
mooncake_engine: MooncakeEngine,
vllm_config: "VllmConfig",
use_layerwise: bool,
):
self.decoder = MsgpackDecoder(torch.Tensor)
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REP, # type: ignore[attr-defined]
bind=True,
)
self.mooncake_engine = mooncake_engine
self.running = True
def process_request():
while self.running:
frames = self.socket.recv_multipart(copy=False)
token_ids = self.decoder.decode(frames)
result = self.mooncake_engine.lookup_scheduler(
token_ids, use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)
self.thread = threading.Thread(target=process_request, daemon=True)
self.thread.start()
def close(self):
self.socket.close(linger=0)
# TODO: close the thread!

View File

@@ -0,0 +1,38 @@
import ipaddress
import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore
_global_te = None
_global_te_lock = threading.Lock()
def get_global_te(hostname: str, device_name: Optional[str]):
try:
ip = ipaddress.ip_address(hostname)
if isinstance(ip, ipaddress.IPv6Address):
raise RuntimeError(
"The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
)
except ValueError:
pass
global _global_te
if _global_te is None:
with _global_te_lock:
# Double-Checked Locking
if _global_te is None:
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else ""
ret_value = transfer_engine.initialize(hostname,
"P2PHANDSHAKE",
"ascend", device_name)
if ret_value != 0:
raise RuntimeError(
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
_global_te = transfer_engine
return _global_te

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,196 @@
from typing import Optional
import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
init_model_parallel_group)
import vllm_npu.envs as envs_ascend
from vllm_npu.ascend_config import get_ascend_config
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
assert _MC2 is not None, ("mc2 group is not initialized")
return _MC2
def get_otp_group() -> GroupCoordinator:
assert _OTP is not None, (
"output tensor parallel group is not initialized")
return _OTP
def get_lmhead_tp_group() -> GroupCoordinator:
assert _LMTP is not None, (
"lm head tensor parallel group is not initialized")
return _LMTP
def get_mlp_tp_group() -> GroupCoordinator:
assert _MLP_TP is not None, ("mlp group is not initialized")
return _MLP_TP
def get_p_tp_group() -> GroupCoordinator:
assert _P_TP is not None, (
"distributed prefill tensor parallel group is not initialized")
return _P_TP
def model_parallel_initialized():
return (_MC2 is not None)
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if model_parallel_initialized():
return
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio
global _P_TP
assert _P_TP is None, (
"distributed prefill tensor parallel group is already initialized")
prefill_tensor_model_parallel_size = pd_tp_ratio
# divide alltoall groups
if pd_head_ratio > 1 and get_current_vllm_config(
).kv_transfer_config.is_kv_producer:
num_head_replica = get_ascend_config().num_head_replica
remote_tp_size = parallel_config.tensor_parallel_size // pd_tp_ratio
if num_head_replica <= 1:
group_ranks = all_ranks.view(
-1, prefill_tensor_model_parallel_size).unbind(0)
else:
group_ranks = all_ranks.clone().view(
parallel_config.data_parallel_size, -1,
num_head_replica) # [DP_size, num_head, num_head_replica]
group_ranks = group_ranks.permute(0, 2, 1)
group_ranks = group_ranks.reshape(
-1,
group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
group_ranks = group_ranks.unsqueeze(-1).view(
parallel_config.data_parallel_size, num_head_replica, -1,
alltoall_group_size
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
group_ranks = group_ranks.reshape(-1,
alltoall_group_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
local_rank = get_world_group().local_rank
num = next(
(i for i, ranks in enumerate(group_ranks) if local_rank in ranks),
None)
_P_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=f"p_tp_{num}")
global _MC2
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_MC2 = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="mc2")
if envs_ascend.vllm_npu_ENABLE_MLP_OPTIMIZE:
global _MLP_TP
assert _MLP_TP is None, (
"mlp tensor model parallel group is already initialized")
mlp_tp = parallel_config.data_parallel_size
all_ranks_mlp_head = torch.arange(world_size).reshape(
-1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa
group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_MLP_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="mlp_tp")
# If oproj tensor parallel size is set, we will create a group for it.
otp_size = get_ascend_config().oproj_tensor_parallel_size
if otp_size is not None:
group_ranks = []
global _OTP
num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
for i in range(num_oproj_tensor_parallel_groups):
ranks = list(range(i * otp_size, (i + 1) * otp_size))
group_ranks.append(ranks)
_OTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="otp")
lmhead_tensor_parallel_size = get_ascend_config(
).lmhead_tensor_parallel_size
if lmhead_tensor_parallel_size is not None:
group_ranks = []
global _LMTP
num_lmhead_tensor_parallel_groups: int = (world_size //
lmhead_tensor_parallel_size)
for i in range(num_lmhead_tensor_parallel_groups):
ranks = list(
range(i * lmhead_tensor_parallel_size,
(i + 1) * lmhead_tensor_parallel_size))
group_ranks.append(ranks)
_LMTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="lmheadtp")
def get_mlp_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_mlp_tp_group().world_size
def get_mlp_tensor_model_parallel_rank():
"""Return world size for the tensor model parallel group."""
return get_mlp_tp_group().rank_in_group
def destroy_ascend_model_parallel():
global _MC2
if _MC2:
_MC2.destroy()
_MC2 = None
global _MLP_TP
if _MLP_TP:
_MLP_TP.destroy()
_MLP_TP = None
global _LMTP
if _LMTP:
_LMTP.destroy()
_LMTP = None
global _OTP
if _OTP:
_OTP.destroy()
_OTP = None
global _P_TP
if _P_TP:
_P_TP.destroy()
_P_TP = None

View File

@@ -0,0 +1,61 @@
import os
import torch
import torch.distributed as dist
from vllm_npu.distributed.parallel_state import get_p_tp_group
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
value: torch.TensorType):
if pd_tp_ratio <= 1:
return None, None
elif key is None or value is None:
raise ValueError("key or value is None")
k_output = alltoall_and_rearrange(pd_tp_ratio, key)
v_output = alltoall_and_rearrange(pd_tp_ratio, value)
return k_output, v_output
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
num_kv_heads = input_tensor.size(1)
output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor,
input_tensor,
group=get_p_tp_group().device_group)
input_tensor = 0
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
output_tensor = 0
return result
def rearrange_output(base_output: torch.Tensor, cut_num: int,
num_kv_heads: int):
size_0 = base_output.size(0)
if size_0 % cut_num != 0:
raise ValueError(
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
)
chunk_size = size_0 // cut_num
reshaped = base_output.view(cut_num, chunk_size, -1)
transposed = reshaped.transpose(0, 1)
return transposed.contiguous().view(size_0, num_kv_heads, -1)
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
def get_transfer_timeout_value():
ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
if len(ascend_transfer_timeout) > 0:
return int(ascend_transfer_timeout)
hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT',
'20')) # type: ignore
hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT',
'7')) # type: ignore
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
3000)