mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
大改
This commit is contained in:
196
vllm_npu/distributed/parallel_state.py
Normal file
196
vllm_npu/distributed/parallel_state.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
||||
init_model_parallel_group)
|
||||
|
||||
import vllm_npu.envs as envs_ascend
|
||||
from vllm_npu.ascend_config import get_ascend_config
|
||||
|
||||
# Currently, mc2 op need their own group coordinator.
|
||||
_MC2: Optional[GroupCoordinator] = None
|
||||
_MLP_TP: Optional[GroupCoordinator] = None
|
||||
_OTP: Optional[GroupCoordinator] = None
|
||||
_LMTP: Optional[GroupCoordinator] = None
|
||||
_P_TP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_mc2_group() -> GroupCoordinator:
|
||||
assert _MC2 is not None, ("mc2 group is not initialized")
|
||||
return _MC2
|
||||
|
||||
|
||||
def get_otp_group() -> GroupCoordinator:
|
||||
assert _OTP is not None, (
|
||||
"output tensor parallel group is not initialized")
|
||||
return _OTP
|
||||
|
||||
|
||||
def get_lmhead_tp_group() -> GroupCoordinator:
|
||||
assert _LMTP is not None, (
|
||||
"lm head tensor parallel group is not initialized")
|
||||
return _LMTP
|
||||
|
||||
|
||||
def get_mlp_tp_group() -> GroupCoordinator:
|
||||
assert _MLP_TP is not None, ("mlp group is not initialized")
|
||||
return _MLP_TP
|
||||
|
||||
|
||||
def get_p_tp_group() -> GroupCoordinator:
|
||||
assert _P_TP is not None, (
|
||||
"distributed prefill tensor parallel group is not initialized")
|
||||
return _P_TP
|
||||
|
||||
|
||||
def model_parallel_initialized():
|
||||
return (_MC2 is not None)
|
||||
|
||||
|
||||
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
||||
if model_parallel_initialized():
|
||||
return
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
backend = torch.distributed.get_backend(get_world_group().device_group)
|
||||
|
||||
# The layout of all ranks: ExternalDP * EP
|
||||
# ExternalDP is the data parallel group that is not part of the model,
|
||||
# every dp rank can generate independently (in verl integration).
|
||||
all_ranks = torch.arange(world_size).reshape(
|
||||
-1, parallel_config.data_parallel_size *
|
||||
parallel_config.tensor_parallel_size)
|
||||
|
||||
pd_tp_ratio = get_ascend_config().pd_tp_ratio
|
||||
pd_head_ratio = get_ascend_config().pd_head_ratio
|
||||
global _P_TP
|
||||
assert _P_TP is None, (
|
||||
"distributed prefill tensor parallel group is already initialized")
|
||||
prefill_tensor_model_parallel_size = pd_tp_ratio
|
||||
# divide alltoall groups
|
||||
if pd_head_ratio > 1 and get_current_vllm_config(
|
||||
).kv_transfer_config.is_kv_producer:
|
||||
num_head_replica = get_ascend_config().num_head_replica
|
||||
remote_tp_size = parallel_config.tensor_parallel_size // pd_tp_ratio
|
||||
if num_head_replica <= 1:
|
||||
group_ranks = all_ranks.view(
|
||||
-1, prefill_tensor_model_parallel_size).unbind(0)
|
||||
else:
|
||||
group_ranks = all_ranks.clone().view(
|
||||
parallel_config.data_parallel_size, -1,
|
||||
num_head_replica) # [DP_size, num_head, num_head_replica]
|
||||
group_ranks = group_ranks.permute(0, 2, 1)
|
||||
group_ranks = group_ranks.reshape(
|
||||
-1,
|
||||
group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
|
||||
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
|
||||
group_ranks = group_ranks.unsqueeze(-1).view(
|
||||
parallel_config.data_parallel_size, num_head_replica, -1,
|
||||
alltoall_group_size
|
||||
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
|
||||
group_ranks = group_ranks.reshape(-1,
|
||||
alltoall_group_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
local_rank = get_world_group().local_rank
|
||||
num = next(
|
||||
(i for i, ranks in enumerate(group_ranks) if local_rank in ranks),
|
||||
None)
|
||||
_P_TP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name=f"p_tp_{num}")
|
||||
|
||||
global _MC2
|
||||
group_ranks = all_ranks.unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
_MC2 = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="mc2")
|
||||
if envs_ascend.vllm_npu_ENABLE_MLP_OPTIMIZE:
|
||||
global _MLP_TP
|
||||
assert _MLP_TP is None, (
|
||||
"mlp tensor model parallel group is already initialized")
|
||||
|
||||
mlp_tp = parallel_config.data_parallel_size
|
||||
|
||||
all_ranks_mlp_head = torch.arange(world_size).reshape(
|
||||
-1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa
|
||||
group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
# message queue broadcaster is only used in tensor model parallel group
|
||||
_MLP_TP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="mlp_tp")
|
||||
|
||||
# If oproj tensor parallel size is set, we will create a group for it.
|
||||
otp_size = get_ascend_config().oproj_tensor_parallel_size
|
||||
if otp_size is not None:
|
||||
group_ranks = []
|
||||
global _OTP
|
||||
num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
|
||||
for i in range(num_oproj_tensor_parallel_groups):
|
||||
ranks = list(range(i * otp_size, (i + 1) * otp_size))
|
||||
group_ranks.append(ranks)
|
||||
_OTP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="otp")
|
||||
|
||||
lmhead_tensor_parallel_size = get_ascend_config(
|
||||
).lmhead_tensor_parallel_size
|
||||
if lmhead_tensor_parallel_size is not None:
|
||||
group_ranks = []
|
||||
global _LMTP
|
||||
num_lmhead_tensor_parallel_groups: int = (world_size //
|
||||
lmhead_tensor_parallel_size)
|
||||
for i in range(num_lmhead_tensor_parallel_groups):
|
||||
ranks = list(
|
||||
range(i * lmhead_tensor_parallel_size,
|
||||
(i + 1) * lmhead_tensor_parallel_size))
|
||||
group_ranks.append(ranks)
|
||||
_LMTP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="lmheadtp")
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return get_mlp_tp_group().world_size
|
||||
|
||||
|
||||
def get_mlp_tensor_model_parallel_rank():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return get_mlp_tp_group().rank_in_group
|
||||
|
||||
|
||||
def destroy_ascend_model_parallel():
|
||||
global _MC2
|
||||
if _MC2:
|
||||
_MC2.destroy()
|
||||
_MC2 = None
|
||||
|
||||
global _MLP_TP
|
||||
if _MLP_TP:
|
||||
_MLP_TP.destroy()
|
||||
_MLP_TP = None
|
||||
|
||||
global _LMTP
|
||||
if _LMTP:
|
||||
_LMTP.destroy()
|
||||
_LMTP = None
|
||||
|
||||
global _OTP
|
||||
if _OTP:
|
||||
_OTP.destroy()
|
||||
_OTP = None
|
||||
|
||||
global _P_TP
|
||||
if _P_TP:
|
||||
_P_TP.destroy()
|
||||
_P_TP = None
|
||||
Reference in New Issue
Block a user