mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git, synced 2026-02-20 11:42:30 +00:00
import contextlib
import copy
import json
import math
import os
import threading
import time
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Tuple

import llm_datadist  # type: ignore
import msgspec
import torch
import zmq
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
                          LLMException, LLMRole)
from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.forward_context import ForwardContext
from vllm.utils import get_ip, logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus

import vllm_npu.envs as envs_ascend
from vllm_npu.distributed.utils import get_transfer_timeout_value
from vllm_npu.utils import AscendSocVersion, get_ascend_soc_version

TORCH_DTYPE_TO_NPU_DTYPE = {
    torch.half: llm_datadist.DataType.DT_FLOAT16,
    torch.float16: llm_datadist.DataType.DT_FLOAT16,
    torch.bfloat16: llm_datadist.DataType.DT_BF16,
    torch.float: llm_datadist.DataType.DT_FLOAT,
    torch.float32: llm_datadist.DataType.DT_FLOAT,
    torch.int8: llm_datadist.DataType.DT_INT8,
    torch.int64: llm_datadist.DataType.DT_INT64,
    torch.int32: llm_datadist.DataType.DT_INT32
}
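
# Illustrative note (not part of the original source): this map translates the
# torch dtype of a registered KV cache into the llm_datadist DataType that a
# CacheDesc expects, e.g.
#   TORCH_DTYPE_TO_NPU_DTYPE[torch.bfloat16]  # -> llm_datadist.DataType.DT_BF16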


class LLMDataDistCMgrEvent(Enum):
    ReqForMetadata = 0
    ReqForFinished = 1


class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
    super_pod_id: str
    server_id: str
    device_id: str
    device_ip: str
    super_device_id: str
    cluster_id: int


@dataclass
class ReqMeta:
    local_block_ids: list[int]
    remote_block_ids: list[int]
    remote_host: str
    remote_port: str
    engine_id: str
    remote_tp_size: str


class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):

    def __init__(self):
        self.requests: dict[str, ReqMeta] = {}
        # Populated by the scheduler in build_connector_meta() with request ids
        # whose blocks must stay alive until the decode side confirms the pull,
        # mapped to their send-expiry deadlines.
        self.reqs_to_send: dict[str, float] = {}

    def add_new_req(self, request_id: str, local_block_ids: list[int],
                    kv_transfer_params: dict[str, Any]):
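        # For reference: kv_transfer_params is what the prefill side returns
        # from request_finished(), e.g. (illustrative values only):
        #   {
        #       "remote_block_ids": [0, 1, 2],
        #       "remote_engine_id": "engine-0",
        #       "remote_host": "10.0.0.1",
        #       "remote_port": 5557,
        #       "remote_tp_size": "2",
        #   }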
        self.requests[request_id] = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=kv_transfer_params["remote_block_ids"],
            engine_id=kv_transfer_params["remote_engine_id"],
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
            remote_tp_size=kv_transfer_params["remote_tp_size"],
        )


class LLMDataDistCMgrConnector(KVConnectorBase_V1):

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: Optional[
                LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler(
                    vllm_config, self.engine_id)
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config)

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(
            self,
            kv_caches: dict[
                str,  # type: ignore[override]
                Tuple[torch.Tensor]]):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """Get the finished receiving and sending requests."""
        assert self.connector_worker is not None
        return self.connector_worker.get_finished(finished_req_ids)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          LLMDataDistCMgrConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """LLMDataDistCMgrConnector does not load layerwise; the load is performed as a single blocking transfer."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata, **kwargs) -> None:
        """LLMDataDistCMgrConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """LLMDataDistCMgrConnector does not save explicitly."""
        pass


class LLMDataDistCMgrConnectorScheduler():

    def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        self.local_ip = get_ip()
        # Cannot retrieve the parallel config here since it is not initialized yet.
        self.local_dp_rank = None
        self.tp_size = None
        if vllm_config.parallel_config.data_parallel_external_lb:
            dp_rank_local = vllm_config.parallel_config.data_parallel_rank
        else:
            dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size

        self.port = (dp_rank_local * tp_size +
                     envs_ascend.vllm_npu_LLMDD_RPC_PORT
                     if dp_rank_local is not None else
                     tp_size + envs_ascend.vllm_npu_LLMDD_RPC_PORT)
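        # Illustrative note (an assumption inferred from the formula above and
        # the matching worker-side computation in listen_for_agent_metadata_req):
        # each (dp_rank, tp_rank) worker listens on
        #   vllm_npu_LLMDD_RPC_PORT + dp_rank * tp_size + tp_rank,
        # so with, e.g., a base port of 5557 and tp_size=2, dp_rank 1 / tp_rank 0
        # listens on 5559. The scheduler advertises the dp-group's base port and
        # the decode worker adds its own tp offset in _read_blocks().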

        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
        self._reqs_need_send: dict[str, float] = {}

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """

        params = request.kv_transfer_params
        logger.debug(
            f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}"
        )

        if params is not None and params.get("do_remote_prefill"):
            # Remote prefill: get all prompt blocks from remote.
            assert num_computed_tokens % self.block_size == 0
            # Note: We use the full token count as transmit data here.
            count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
            return count, count > 0

        # No remote prefill for this request.
        return 0, False

    def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
                                 num_external_tokens: int):
        params = request.kv_transfer_params
        logger.debug(
            f"LLMDataDistCMgrConnector update states num_external_tokens: {num_external_tokens} kv_transfer_params: {params}"
        )
        if params is not None and params.get("do_remote_prefill"):
            if params.get("remote_block_ids"):
                if all(p in params for p in ("remote_engine_id", "remote_host",
                                             "remote_port", "remote_tp_size")):
                    self._reqs_need_recv[request.request_id] = (
                        request, blocks.get_unhashed_block_ids())
                else:
                    logger.warning(
                        f"Invalid KVTransferParams {params}, this request will be discarded"
                    )
            else:
                assert num_external_tokens == 0
            params["do_remote_prefill"] = False

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        meta = LLMDataDistCMgrConnectorMetadata()

        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert req.kv_transfer_params is not None
            meta.add_new_req(request_id=req_id,
                             local_block_ids=block_ids,
                             kv_transfer_params=req.kv_transfer_params)

        meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        self._reqs_need_send.clear()

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:

        params = request.kv_transfer_params
        logger.debug(
            "LLMDataDistCMgrConnector request_finished, request_status=%s, "
            "kv_transfer_params=%s", request.status, params)

        if (params is None or not params.get("do_remote_decode")
                or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
            return False, None

        # Note: the NIXL connector transfers full blocks only, but there is no
        # strong reason to do that here, so we just transfer whatever data the
        # prefill node has computed.
        # Note: there might be issues with this; check it if there are any
        # unexpected results.
        computed_block_ids = block_ids
        delay_free_blocks = len(computed_block_ids) > 0
        if delay_free_blocks:
            logger.info("Delaying free of %d blocks for request %s",
                        len(computed_block_ids), request.request_id)
            # Prefill request on remote. It will be read from D upon completion
            self._reqs_need_send[request.request_id] = (
                time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
        return delay_free_blocks, dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_block_ids=computed_block_ids,
            remote_engine_id=self.engine_id,
            remote_host=self.local_ip,
            remote_port=self.port,
            remote_tp_size=str(
                self.vllm_config.parallel_config.tensor_parallel_size),
        )


class LLMDataDistCMgrConnectorWorker():
    """
    Implementation of Worker side methods
    """

    def __init__(self, vllm_config: VllmConfig):
        assert vllm_config.kv_transfer_config is not None
        logger.info("Initialize the LLMDataDistCMgrConnectorWorker")
        # We assume the local node only contains dp and tp, and tp will not
        # communicate inter-node. For any scenario beyond this scope, the
        # functionality of this connector is not guaranteed.
        self.local_rank_on_node = get_world_group().rank % (
            vllm_config.parallel_config.data_parallel_size_local *
            vllm_config.parallel_config.tensor_parallel_size)
        self.local_rank = get_world_group().local_rank
        if vllm_config.parallel_config.data_parallel_external_lb:
            self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank
        else:
            self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
        self.tp_rank = get_tp_group().rank_in_group
        self.rank = get_world_group().rank
        self.local_ip = get_ip()
        self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
        self.local_agent_metadata: Optional[
            LLMDataDistCMgrAgentMetadata] = None
        self.vllm_config = vllm_config
        self.executor = ThreadPoolExecutor(1)
        self.thread_lock = threading.Lock()

        self.llm_datadist_role = None
        self.llm_datadist_remote_role = None
        if self.kv_transfer_config.kv_role == "kv_producer":
            self.llm_datadist_role = LLMRole.PROMPT
            self.llm_datadist_remote_role = LLMRole.DECODER
        elif self.kv_transfer_config.kv_role == "kv_consumer":
            self.llm_datadist_role = LLMRole.DECODER
            self.llm_datadist_remote_role = LLMRole.PROMPT
        else:
            raise RuntimeError(
                f"LLMDataDistWorker: received an unexpected kv role; this worker only supports kv_producer and kv_consumer, but got {vllm_config.kv_transfer_config.kv_role}"
            )

        # linked_cluster records the clusters that have already built a
        # connection; its format is {cluster_id: comm_name}.
        self.linked_cluster: dict[Any, Any] = {}
        self.prefill_device_list: list[tuple[int, int]] = []
        self.decode_device_list: list[tuple[int, int]] = []
        global_rank_table = self.read_offline_rank_table()
        self.local_agent_metadata = self.read_agent_metadata(global_rank_table)
        self.llm_datadist = LLMDataDist(self.llm_datadist_role,
                                        self.local_agent_metadata.cluster_id)
        self.init_llm_datadist()
        self.finished_reqs: set[str] = set()
        self.soc_info = get_ascend_soc_version()
        # Set hccl deterministic for model execute
        os.environ["HCCL_DETERMINISTIC"] = "true"
        self.done_receiving_counts: defaultdict[str,
                                                set[int]] = defaultdict(set)
        self.reqs_to_send: dict[str, float] = {}
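
    # Wire protocol note (descriptive comment added for clarity): peers talk to
    # the listener below over ZMQ using msgpack-encoded two-element messages
    # [event, payload], where event is a LLMDataDistCMgrEvent value:
    #   * ReqForMetadata: payload is the sender's LLMDataDistCMgrAgentMetadata;
    #     the listener replies with its own metadata and links the clusters.
    #   * ReqForFinished: payload is [request_id]; the listener marks the
    #     request finished and acknowledges with a short byte string.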

    def listen_for_agent_metadata_req(self, event: threading.Event):
        assert self.local_agent_metadata is not None
        port = (envs_ascend.vllm_npu_LLMDD_RPC_PORT +
                self.local_dp_rank * self.tp_size + self.tp_rank
                if self.local_dp_rank is not None else
                envs_ascend.vllm_npu_LLMDD_RPC_PORT + self.tp_size +
                self.tp_rank)
        url = f"tcp://{envs_ascend.vllm_npu_LLMDD_RPC_IP}:{port}"
        msg_encoder = msgspec.msgpack.Encoder()
        msg_decoder = msgspec.msgpack.Decoder()
        msg_to_send = msg_encoder.encode(self.local_agent_metadata)
        logger.debug(f"Start to listen to address: {url}")
        logger.debug(
            f"The local agent metadata has {len(msg_to_send)} bytes here")
        logger.info(
            f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} starts listening for requests from peers"
        )
        with zmq_ctx(zmq.ROUTER, url) as sock:  # type: ignore[attr-defined]
            event.set()
            while True:
                identity, _, msg = sock.recv_multipart()
                event_msg, decode_msg = msg_decoder.decode(msg)
                event_msg = LLMDataDistCMgrEvent(event_msg)
                if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
                    if "cluster_id" in decode_msg:
                        decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
                        logger.info(
                            f"LLMDataDistCMgrConnectorWorker: Received message from cluster {decode_msg.cluster_id}"
                        )
                        sock.send_multipart((identity, b"", msg_to_send))
                        self.add_remote_agent(decode_msg)
                    else:
                        logger.warning(
                            f"LLMDataDistCMgrConnectorWorker: received unrecognized data {decode_msg}"
                        )
                elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
                    finished_req_id = decode_msg[0]
                    with self.thread_lock:
                        logger.debug(
                            f"LLMDataDistCMgrConnectorWorker: Request {finished_req_id} reported finished by the remote peer"
                        )
                        if finished_req_id in self.reqs_to_send:
                            self.finished_reqs.add(finished_req_id)
                            del self.reqs_to_send[finished_req_id]
                    sock.send_multipart(
                        (identity, b"", b"receiving decode finished"))
                else:
                    raise RuntimeError(
                        f"LLMDataDistCMgrConnectorWorker: received an unexpected request event {event_msg} from remote!"
                    )

    def init_llm_datadist(self):
        assert self.local_agent_metadata is not None
        llm_config = LLMConfig()
        llm_config.device_id = self.local_rank
        llm_config.sync_kv_timeout = get_transfer_timeout_value()
        llm_config.enable_switch_role = True
        llm_config.enable_cache_manager = True
        llm_config.enable_remote_cache_accessible = True
        llm_config_options = llm_config.generate_options()
        self.llm_datadist.init(llm_config_options)
        self.cache_manager = self.llm_datadist.cache_manager
        logger.info(
            f"Done initializing llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}"
        )

    def read_offline_rank_table(self):
        assert (
            envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
        ), "Please set the rank table path in the environment variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH"
        rank_table_path = envs_ascend.DISAGGREGATED_PREFILL_RANK_TABLE_PATH
        with open(rank_table_path, "r", encoding="utf-8") as f:
            global_rank_table = json.load(f)
        decode_device_list = global_rank_table["decode_device_list"]
        for decode_device in decode_device_list:
            server_id = decode_device["server_id"]
            device_id = decode_device["device_id"]
            self.decode_device_list.append((server_id, device_id))
        prefill_device_list = global_rank_table["prefill_device_list"]
        for prefill_device in prefill_device_list:
            server_id = prefill_device["server_id"]
            device_id = prefill_device["device_id"]
            self.prefill_device_list.append((server_id, device_id))

        # global_rank_table = json.dumps(global_rank_table)
        return global_rank_table

    @staticmethod
    def _get_visible_devices() -> Callable[[str], bool]:
        """
        Return a predicate that checks whether the given device ID is visible,
        i.e. ASCEND_RT_VISIBLE_DEVICES is not set or contains the device_id.
        """
        visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
        if not visible_devices:
            return lambda device_id: True
        visible_device_list = visible_devices.split(",")
        return lambda device_id: device_id in visible_device_list

    def read_agent_metadata(self, global_rank_table):
        device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
        devices_type_list = []
        agent_metadata = None
        if self.llm_datadist_role == LLMRole.PROMPT:
            devices_type_list.append("prefill_device_list")
        elif self.llm_datadist_role == LLMRole.DECODER:
            devices_type_list.append("decode_device_list")
        else:
            devices_type_list.append("prefill_device_list")
            devices_type_list.append("decode_device_list")
        for device_type in devices_type_list:
            device_list = global_rank_table[device_type]
            device_list = [
                d for d in device_list if d.get("server_id") == self.local_ip
                and device_filter(d.get("device_id", ""))
            ]
            if len(device_list) <= self.tp_rank:
                continue
            device_info = device_list[self.tp_rank]
            super_pod_id_ = device_info.get("super_pod_id", None)
            server_id_ = device_info["server_id"]
            device_id_ = device_info["device_id"]
            device_ip_ = device_info["device_ip"]
            super_device_id_ = device_info.get("super_device_id", None)
            cluster_id_ = int(device_info["cluster_id"])
            agent_metadata = LLMDataDistCMgrAgentMetadata(
                super_pod_id=super_pod_id_,
                server_id=server_id_,
                device_id=device_id_,
                device_ip=device_ip_,
                super_device_id=super_device_id_,
                cluster_id=cluster_id_,
            )
        assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from the rank table"
        return agent_metadata

    def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
        _, first_kv_cache_tuple = next(iter(kv_caches.items()))
        first_kv_cache = first_kv_cache_tuple[0]
        assert len(first_kv_cache_tuple) > 1
        assert self.local_agent_metadata is not None
        kv_cache_dtype = first_kv_cache.dtype
        self.use_mla: bool = first_kv_cache_tuple[0].size(
            -1) != first_kv_cache_tuple[1].size(-1) and len(
                first_kv_cache_tuple) == 2
        self.use_sparse: bool = len(first_kv_cache_tuple) == 3
        # MLA case. [2 (k_normed, k_pe), num_blocks, ...]
        # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
        # MHA case. [2 (k and v), num_blocks, ...]
        self.num_blocks = first_kv_cache.shape[0]
        block_rank = 3  # [block_size, latent_dim]
        block_shape = first_kv_cache.shape[-block_rank:]

        self.block_len = math.prod(block_shape)
        self.cache_addr: list[int] = []
        alignment = 2 * 1024 * 1024
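        # Note (added for clarity): 2 MB is the alignment the non-MLA branch
        # below asserts on every registered KV-cache base address before
        # handing it to the cache manager.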
        if self.use_mla:
            cache_k_normed_addr_list = []
            cache_k_pe_addr_list = []
            k_normed = None
            k_pe = None
            for cache_or_caches in kv_caches.values():
                assert len(cache_or_caches) > 1
                k_normed, k_pe = cache_or_caches[0], cache_or_caches[1]
                cache_k_normed_addr_list.append(k_normed.data_ptr())
                cache_k_pe_addr_list.append(k_pe.data_ptr())
            self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list)

            cache_desc_k_normed = CacheDesc(
                len(self.cache_addr[0]), [*k_normed.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_desc_k_pe = CacheDesc(
                len(self.cache_addr[1]), [*k_pe.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_key_k_normed = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id),
                model_id=0)
            cache_key_k_pe = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id),
                model_id=1)
            self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe)
            self.cache_key = (cache_key_k_normed, cache_key_k_pe)
            try:
                cache_k_normed = self.cache_manager.register_blocks_cache(
                    self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
                cache_k_pe = self.cache_manager.register_blocks_cache(
                    self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
                self.cache = (cache_k_normed, cache_k_pe)
                logger.info("LLMDataDistWorker: End of register Paged Cache.")
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Unexpected parameters passed to register_blocks_cache; received [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
                )
        elif self.use_sparse:
            cache_k_normed_addr_list = []
            cache_k_pe_addr_list = []
            cache_k_idx_addr_list = []
            k_normed = None
            k_pe = None
            k_idx = None
            for cache_or_caches in kv_caches.values():
                assert len(cache_or_caches) > 1
                k_normed, k_pe, k_idx = (cache_or_caches[0],
                                         cache_or_caches[1],
                                         cache_or_caches[2])
                cache_k_normed_addr_list.append(k_normed.data_ptr())
                cache_k_pe_addr_list.append(k_pe.data_ptr())
                cache_k_idx_addr_list.append(k_idx.data_ptr())
            self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
                               cache_k_idx_addr_list)

            cache_desc_k_normed = CacheDesc(
                len(self.cache_addr[0]), [*k_normed.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_desc_k_pe = CacheDesc(
                len(self.cache_addr[1]), [*k_pe.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_desc_k_idx = CacheDesc(
                len(self.cache_addr[2]), [*k_idx.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_key_k_normed = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id),
                model_id=0)
            cache_key_k_pe = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id),
                model_id=1)
            cache_key_k_idx = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id),
                model_id=2)
            self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
                               cache_desc_k_idx)
            self.cache_key = (cache_key_k_normed, cache_key_k_pe,
                              cache_key_k_idx)
            try:
                cache_k_normed = self.cache_manager.register_blocks_cache(
                    self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
                cache_k_pe = self.cache_manager.register_blocks_cache(
                    self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
                cache_k_idx = self.cache_manager.register_blocks_cache(
                    self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
                self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
                logger.info("LLMDataDistWorker: End of register Paged Cache.")
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Unexpected parameters passed to register_blocks_cache; received [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
                )
        else:
            for cache_or_caches in kv_caches.values():
                for cache in cache_or_caches:
                    base_addr = cache.data_ptr()
                    assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
                    self.cache_addr.append(base_addr)
            # register paged kv cache into the llm_cache manager
            self.cache_desc = CacheDesc(
                len(self.cache_addr), [*cache.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            self.cache_key = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id))
            logger.info(
                f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}"
            )
            try:
                self.cache = self.cache_manager.register_blocks_cache(
                    self.cache_desc, self.cache_addr, self.cache_key)
                logger.info(
                    "LLMDataDistCMgrConnectorWorker: End of register Paged Cache."
                )
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Unexpected parameters passed to register_blocks_cache; received [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
                )
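        # Note (added for clarity): the daemon thread below serves metadata and
        # finished-request messages from remote peers over ZMQ; ready_event
        # guarantees the ROUTER socket is bound before register_kv_caches
        # returns, so peers can connect as soon as this worker is up.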
        self.ready_event = threading.Event()
        self.metadata_agent_listener_t = threading.Thread(
            target=self.listen_for_agent_metadata_req,
            args=(self.ready_event, ),
            daemon=True,
            name="metadata_agent_listener")
        self.metadata_agent_listener_t.start()
        self.ready_event.wait()

    def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
        futures = []
        for req_id, meta in metadata.requests.items():
            logger.debug(f"Start to transmit {req_id}")
            future = self.executor.submit(
                self._read_blocks,
                local_block_ids=meta.local_block_ids,
                remote_block_ids=meta.remote_block_ids,
                remote_ip=meta.remote_host,
                remote_port=int(meta.remote_port),
                remote_engine_id=meta.engine_id,
                request_id=req_id,
                remote_tp_size=meta.remote_tp_size,
            )
            futures.append(future)

        def handle_exception(future):
            if future.exception():
                logger.error(f"KV transfer task failed: {future.exception()}")

        for future in futures:
            future.add_done_callback(handle_exception)
        self.reqs_to_send.update(metadata.reqs_to_send)

    def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
        assert self.local_agent_metadata is not None
        remote_cluster_id = metadata.cluster_id
        if remote_cluster_id in self.linked_cluster:
            logger.debug(
                f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} is already linked with this server, skipping the connection"
            )
            return remote_cluster_id
        remote_super_pod_id = metadata.super_pod_id
        remote_server_id = metadata.server_id
        is_same_server = remote_server_id == self.local_agent_metadata.server_id
        is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id
        if self.llm_datadist_role == LLMRole.PROMPT:
            prefill_metadata = self.local_agent_metadata
            decode_metadata = metadata
        else:
            prefill_metadata = metadata
            decode_metadata = self.local_agent_metadata
        comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}"
        cluster_rank_info = {
            prefill_metadata.cluster_id: 0,
            decode_metadata.cluster_id: 1
        }
        rank_table = {}
        rank_table["version"] = "1.2"
        rank_table["server_count"] = "1" if is_same_server else "2"
        rank_table["status"] = "completed"

        # generate server_list for rank table
        rank_table["server_list"] = []  # type: ignore[assignment]
        decode_server_device_info = None
        prefill_server_device_info = {
            "device": [{
                k: v
                for k, v in [("device_id", prefill_metadata.device_id),
                             ("device_ip", prefill_metadata.device_ip),
                             ("super_device_id",
                              prefill_metadata.super_device_id),
                             ("rank_id", "0")] if v is not None
            }],
            "server_id": prefill_metadata.server_id
        }
        if is_same_server:
            prefill_server_device_info["device"].append(  # type: ignore[attr-defined]
                {
                    k: v
                    for k, v in [("device_id", decode_metadata.device_id),
                                 ("device_ip", decode_metadata.device_ip),
                                 ("super_device_id",
                                  decode_metadata.super_device_id),
                                 ("rank_id", "1")] if v is not None
                })
        else:
            decode_server_device_info = {
                "device": [{
                    k: v
                    for k, v in [("device_id", decode_metadata.device_id),
                                 ("device_ip", decode_metadata.device_ip),
                                 ("super_device_id",
                                  decode_metadata.super_device_id),
                                 ("rank_id", "1")] if v is not None
                }],
                "server_id": decode_metadata.server_id
            }
        rank_table["server_list"].append(  # type: ignore[attr-defined]
            prefill_server_device_info)
        if decode_server_device_info is not None:
            rank_table["server_list"].append(  # type: ignore[attr-defined]
                decode_server_device_info)

        if self.soc_info == AscendSocVersion.A3:
            # generate super_pod_list for rank table
            super_pod_list = []
            prefill_super_pod_info = {
                "super_pod_id": prefill_metadata.super_pod_id,
                "server_list": [{
                    "server_id": prefill_metadata.server_id
                }],
            }
            if is_same_pod and not is_same_server:
                prefill_super_pod_info[
                    "server_list"].append(  # type: ignore[attr-defined]
                        {"server_id": decode_metadata.server_id})
            super_pod_list.append(prefill_super_pod_info)
            if not is_same_pod:
                decode_super_pod_id = {
                    "super_pod_id": decode_metadata.super_pod_id,
                    "server_list": [{
                        "server_id": decode_metadata.server_id
                    }],
                }
                super_pod_list.append(decode_super_pod_id)
            rank_table[
                "super_pod_list"] = super_pod_list  # type: ignore[assignment]
        logger.info(
            f"LLMDataDistCMgrConnectorWorker: trying to link with remote, comm name: {comm_name}"
        )
        logger.info(f"rank table \n{rank_table}")
        logger.info(f"comm name: {comm_name}")
        logger.info(f"cluster rank info: {cluster_rank_info}")
        comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
                                         json.dumps(rank_table))
        while True:
            ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
            if ret == llm_datadist.RegisterMemStatus.OK:
                logger.info(
                    f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
                )
                break
            elif ret == llm_datadist.RegisterMemStatus.FAILED:
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
                )
            time.sleep(1)
            logger.info("Checking query_register_mem_status again")
        self.linked_cluster.update({remote_cluster_id: comm_id})
        logger.info(f"cached linked cluster: {self.linked_cluster}")
        logger.info(
            f"Successfully built link with cluster id {remote_cluster_id} (comm name {comm_name})!"
        )
        return remote_cluster_id

    def remove_remote_agent(self, cluster_id: int):
        if cluster_id not in self.linked_cluster:
            logger.warning(
                f"LLMDataDistCMgrConnectorWorker: Can't remove remote client with cluster id {cluster_id} because it does not exist in the linked_cluster list"
            )
            return
        comm_id = self.linked_cluster[cluster_id]
        try:
            self.llm_datadist.unlink(comm_id)
            self.linked_cluster.pop(cluster_id)
        except LLMException:
            logger.error(
                f"Failed to remove remote client with cluster id {cluster_id}! The program won't terminate, but please check your environment carefully"
            )
        else:
            logger.info(
                f"Successfully removed remote client with cluster id {cluster_id}!"
            )

    def connect_to_remote_agent(self, host: str, port: int) -> int:
        url = f"tcp://{host}:{port}"
        logger.debug(f"Querying metadata from url: {url}")
        msg_encoder = msgspec.msgpack.Encoder()
        msg_send = msg_encoder.encode(
            [LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
        with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
            logger.info("Requesting remote metadata from socket......")
            sock.send(msg_send)
            metadata_bytes = sock.recv()
            decoder = msgspec.msgpack.Decoder()
            metadata = decoder.decode(metadata_bytes)
            metadata = LLMDataDistCMgrAgentMetadata(**metadata)
            logger.info(f"received metadata: {metadata}")
            cluster_id = self.add_remote_agent(metadata)
        return cluster_id

    def send_finish_to_remote(self, host: str, ports: list[int], request_id):
        for port in ports:
            url = f"tcp://{host}:{port}"
            logger.debug(f"Sending finished to remote: {url}")
            msg_encoder = msgspec.msgpack.Encoder()
            msg_send = msg_encoder.encode(
                [LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
            with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
                try:
                    sock.send(msg_send)
                    logger.debug(
                        f"Request id {request_id} finished message sent to remote {url}"
                    )
                    _ = sock.recv()
                except Exception as e:
                    logger.error(
                        f"Failed to send request_id {request_id} to prefill: {e}"
                    )

    def _read_blocks(
        self,
        local_block_ids: list[int],
        remote_block_ids: list[int],
        remote_ip: str,
        remote_port: int,
        remote_engine_id: str,
        request_id: str,
        remote_tp_size: str,
    ):
        # if remote_ip not in self.linked_cluster:
        tp_offset = self.tp_rank % int(remote_tp_size)
        remote_cluster_id = self.connect_to_remote_agent(
            remote_ip, remote_port + tp_offset)
        num_local_blocks = len(local_block_ids)
        if num_local_blocks == 0:
            return
        num_remote_blocks = len(remote_block_ids)
        assert num_local_blocks <= num_remote_blocks
        if num_local_blocks < num_remote_blocks:
            remote_block_ids = remote_block_ids[-num_local_blocks:]
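            # Note (added for clarity): when some prompt tokens are already
            # computed locally (e.g. via prefix-cache hits), fewer local blocks
            # are needed, so only the trailing remote blocks, which hold the
            # not-yet-computed suffix, are pulled.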

        logger.info(f"remote cluster id is: {remote_cluster_id}")
        if self.use_mla:
            remote_cache_key_k_normed = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=0)
            remote_cache_key_k_pe = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=1)
            logger.info("Trying to pull blocks from the remote server")
            try:
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_normed,
                    self.cache[0],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_pe,
                    self.cache[1],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Unexpected parameters passed to pull_blocks; remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
                )
            except LLMException:
                raise RuntimeError(
                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks; try increasing the sync_kv_timeout config or check your connection status"
                )
        elif self.use_sparse:
            remote_cache_key_k_normed = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=0)
            remote_cache_key_k_pe = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=1)
            remote_cache_key_k_idx = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=2)
            logger.info("Trying to pull blocks from the remote server")
            try:
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_normed,
                    self.cache[0],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_pe,
                    self.cache[1],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_idx,
                    self.cache[2],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Unexpected parameters passed to pull_blocks; remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
                )
            except LLMException:
                raise RuntimeError(
                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks; try increasing the sync_kv_timeout config or check your connection status"
                )
        else:
            remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
            logger.info("Trying to pull blocks from the remote server")
            try:
                self.cache_manager.pull_blocks(
                    remote_cache_key,
                    self.cache,  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Unexpected parameters passed to pull_blocks; remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
                )
            except LLMException:
                raise RuntimeError(
                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks; try increasing the sync_kv_timeout config or check your connection status"
                )
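        # Note (added for clarity): the pull is complete, so notify the remote
        # peers. The port range below covers the remote tp ranks this local tp
        # rank is responsible for, starting at its own offset and stepping by
        # the local tp size.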
        remote_ports = list(
            range(remote_port + self.tp_rank,
                  remote_port + int(remote_tp_size), self.tp_size))
        self.send_finish_to_remote(remote_ip, remote_ports, request_id)
        with self.thread_lock:
            self.finished_reqs.add(request_id)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """Get the finished receiving and sending requests."""
        now = time.perf_counter()
        with self.thread_lock:
            while self.reqs_to_send:
                req_id, expires = next(iter(self.reqs_to_send.items()))
                if now < expires:
                    break
                logger.warning(
                    "Some requests on the prefill node failed to receive the KV cache transfer-done signal. "
                    "If a higher mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition."
                )
                if req_id in self.reqs_to_send:
                    self.finished_reqs.add(req_id)
                    del self.reqs_to_send[req_id]
            req_ids_to_ret = copy.deepcopy(self.finished_reqs)
            self.finished_reqs.clear()
        if self.llm_datadist_role == LLMRole.PROMPT:
            return req_ids_to_ret, None
        else:
            return None, req_ids_to_ret


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
            addr: str) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Context manager for a ZMQ socket"""

    ctx: Optional[zmq.Context] = None  # type: ignore[name-defined]
    try:
        ctx = zmq.Context()  # type: ignore[attr-defined]

        if socket_type == zmq.ROUTER:  # type: ignore[attr-defined]
            socket = ctx.socket(zmq.ROUTER)  # type: ignore[attr-defined]
            socket.bind(addr)
        elif socket_type == zmq.REQ:  # type: ignore[attr-defined]
            socket = ctx.socket(zmq.REQ)  # type: ignore[attr-defined]
            socket.connect(addr)
        else:
            raise ValueError(f"Unexpected socket type: {socket_type}")

        yield socket
    finally:
        if ctx is not None:
            ctx.destroy(linger=0)