Files
vllm-npu-plugin/vllm_npu/distributed/device_communicators/pyhccl_wrapper.py
2026-02-10 23:08:39 +08:00

254 lines
8.4 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
from vllm.logger import logger
from vllm_npu.utils import find_hccl_library
# export types and functions from hccl to Python ===
# for the original hccl definition, please check
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl.h#L90
# https://github.com/EternalLied/cann-hccl-new/blob/64ec6ce2923319caa5df8c3c531e06bdc148ce9c/inc/hccl/hccl_types.h#L48
hcclResult_t = ctypes.c_int
hcclComm_t = ctypes.c_void_p
class hcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 4108)]
aclrtStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
hcclDataType_t = ctypes.c_int
class hcclDataTypeEnum:
hcclInt8 = 0
hcclInt16 = 1
hcclInt32 = 2
hcclFloat16 = 3
hcclFloat32 = 4
hcclInt64 = 5
hcclUint64 = 6
hcclUint8 = 7
hcclUint16 = 8
hcclUint32 = 9
hcclFloat64 = 10
hcclBfloat16 = 11
hcclInt128 = 12
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.hcclInt8
if dtype == torch.uint8:
return cls.hcclUint8
if dtype == torch.int32:
return cls.hcclInt32
if dtype == torch.int64:
return cls.hcclInt64
if dtype == torch.float16:
return cls.hcclFloat16
if dtype == torch.float32:
return cls.hcclFloat32
if dtype == torch.float64:
return cls.hcclFloat64
if dtype == torch.bfloat16:
return cls.hcclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
hcclRedOp_t = ctypes.c_int
class hcclRedOpTypeEnum:
hcclSum = 0
hcclProd = 1
hcclMax = 2
hcclMin = 3
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.hcclSum
if op == ReduceOp.PRODUCT:
return cls.hcclProd
if op == ReduceOp.MAX:
return cls.hcclMax
if op == ReduceOp.MIN:
return cls.hcclMin
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class HCCLLibrary:
exported_functions = [
# const char* HcclGetErrorString(HcclResult code);
Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
# HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
Function("HcclGetRootInfo", hcclResult_t,
[ctypes.POINTER(hcclUniqueId)]),
# HcclResult HcclCommInitRootInfo(
# uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
# note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
Function("HcclCommInitRootInfo", hcclResult_t, [
ctypes.c_int,
ctypes.POINTER(hcclUniqueId),
ctypes.c_int,
ctypes.POINTER(hcclComm_t),
]),
# HcclResult HcclAllReduce(
# void *sendBuf, void *recvBuf, uint64_t count,
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
# aclrtStream stream);
Function("HcclAllReduce", hcclResult_t, [
buffer_type,
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
hcclRedOp_t,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclBroadcast(
# void *buf, uint64_t count,
# HcclDataType dataType, uint32_t root,
# HcclComm comm, aclrtStream stream);
Function("HcclBroadcast", hcclResult_t, [
buffer_type,
ctypes.c_size_t,
hcclDataType_t,
ctypes.c_int,
hcclComm_t,
aclrtStream_t,
]),
# HcclResult HcclCommDestroy(HcclComm comm);
Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the correspongding directory
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_hccl_library()
try:
if so_file not in HCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
HCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = HCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
logger.error(
"Failed to load HCCL library from %s. "
"It is expected if you are not running on Ascend NPUs."
"Otherwise, the hccl library might not exist, be corrupted "
"or it does not support the current platform %s. "
"If you already have the library, please set the "
"environment variable HCCL_SO_PATH"
" to point to the correct hccl library path.", so_file,
platform.platform())
raise e
if so_file not in HCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in HCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
HCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = HCCLLibrary.path_to_dict_mapping[so_file]
def hcclGetErrorString(self, result: hcclResult_t) -> str:
return self._funcs["HcclGetErrorString"](result).decode("utf-8")
def HCCL_CHECK(self, result: hcclResult_t) -> None:
if result != 0:
error_str = self.hcclGetErrorString(result)
raise RuntimeError(f"HCCL error: {error_str}")
def hcclGetUniqueId(self) -> hcclUniqueId:
unique_id = hcclUniqueId()
self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](
ctypes.byref(unique_id)))
return unique_id
def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId,
rank: int) -> hcclComm_t:
comm = hcclComm_t()
self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"](
world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)))
return comm
def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
# `datatype` actually should be `hcclDataType_t`
# and `op` should be `hcclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int,
root: int, comm: hcclComm_t,
stream: aclrtStream_t) -> None:
self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype,
root, comm, stream))
def hcclCommDestroy(self, comm: hcclComm_t) -> None:
self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
__all__ = [
"HCCLLibrary",
"hcclDataTypeEnum",
"hcclRedOpTypeEnum",
"hcclUniqueId",
"hcclComm_t",
"aclrtStream_t",
"buffer_type",
]