2026-02-10 23:08:39 +08:00
parent 1baa36026c
commit 6680585975
172 changed files with 52867 additions and 892 deletions

View File

@@ -0,0 +1,44 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# TODO: Remove this adaptor once https://github.com/vllm-project/vllm/issues/22246 is merged in vLLM.
from abc import abstractmethod
from typing import Any
class EplbAdaptor():
def __init__(self, **args):
pass
@abstractmethod
def get_rank_expert_workload(self):
raise NotImplementedError
@abstractmethod
def get_init_expert_map(self, num_moe_layers: Any) -> Any:
raise NotImplementedError
@abstractmethod
def do_update_expert_map(self, layer_id: Any,
updated_expert_map: Any) -> Any:
raise NotImplementedError
@abstractmethod
def do_update_expert_weight(self, layer_id: Any,
local_expert_to_replace: Any,
buffer_tensor_id: Any) -> Any:
raise NotImplementedError
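The interface above is consumed by the worker and weight-loader files later in this commit. As a rough, hypothetical sketch (not part of the commit), a minimal in-memory adaptor only needs to expose the workload and expert-map accessors:

# Hypothetical minimal adaptor; all names and shapes below are illustrative only.
import torch

class ToyEplbAdaptor(EplbAdaptor):

    def __init__(self, num_moe_layers: int, num_experts: int, **args):
        super().__init__(**args)
        self.workload = torch.zeros(num_moe_layers, num_experts)
        self.expert_map = {
            i: torch.arange(num_experts) for i in range(num_moe_layers)
        }

    def get_rank_expert_workload(self):
        return self.workload

    def get_init_expert_map(self, num_moe_layers):
        return torch.stack(
            [self.expert_map[i] for i in range(num_moe_layers)])

    def do_update_expert_map(self, layer_id, updated_expert_map):
        self.expert_map[layer_id].copy_(updated_expert_map)

    def do_update_expert_weight(self, layer_id, local_expert_to_replace,
                                buffer_tensor_id):
        # a real adaptor copies the staged buffer into the expert weights here
        pass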

View File

@@ -0,0 +1,289 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# TODO: Remove this adaptor once https://github.com/vllm-project/vllm/issues/22246 is merged in vLLM.
import json
from typing import Any
import torch
import torch.distributed as dist
from vllm.logger import logger
from vllm_npu.ascend_config import get_ascend_config
from vllm_npu.eplb.adaptor.abstract_adaptor import EplbAdaptor
class VllmEplbAdaptor(EplbAdaptor):
def __init__(self, model, **args):
super().__init__(**args)
self.model = model
self.rank_id = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_dict = dict(self.model.named_parameters())
if self.model.config.model_type == "qwen3_moe":
self.num_dense_layers = 0
self.global_expert_num = self.model.config.num_experts
else:
self.num_dense_layers = self.model.config.first_k_dense_replace
self.global_expert_num = self.model.config.n_routed_experts
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
self.init_redundancy_expert = get_ascend_config(
).init_redundancy_expert
        # TODO: init self.expert_weight_names depending on the model type; only DeepSeek V3 W8A8 and Qwen3-MoE are supported here
if self.model.quant_config is not None:
self.expert_weight_names = [
"w13_weight", "w2_weight", "w13_weight_scale",
"w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
]
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]
self.expert_map_per_layer = dict(
) # reference to expert map on device for expert map update
        self.expert_map_per_layer_cpu = dict(
        )  # CPU copy of the expert map, kept to avoid frequent device synchronization
for layer_idx in range(self.num_moe_layers):
self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
self.model.get_expert_map(self.num_dense_layers + layer_idx)
        # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
num_buffer_tensor = torch.where(
self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
self.buffer_tensor_list: list[list[Any]] = [
[] for _ in range(num_buffer_tensor)
]
self.init_buffer_tensor(num_buffer_tensor)
self.expert_param_per_layer = dict()
self.init_expert_param_per_layer()
self.log2phy_map_per_layer = dict()
for layer_idx in range(self.num_moe_layers):
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
self.all_topk_ids = []
def init_buffer_tensor(self, num_buffer_tensor):
for buffer_id in range(num_buffer_tensor):
for name in self.expert_weight_names:
complete_name = "model.layers." + str(
self.num_dense_layers) + ".mlp.experts." + name
expert_tensor = self.param_dict[complete_name].data[0]
if name in ["w13_weight", "w2_weight"]:
expert_tensor = expert_tensor.clone()
buffer_tensor = torch.empty_like(expert_tensor)
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
def init_expert_param_per_layer(self):
num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \
".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
for moe_layer_id in range(self.num_moe_layers):
layer_idx = self.num_dense_layers + moe_layer_id
self.expert_param_per_layer[layer_idx] = list()
for local_expert_id in range(num_local_expert):
self.expert_param_per_layer[layer_idx].append([
self.param_dict["model.layers." + str(layer_idx) +
".mlp.experts." +
name].data[local_expert_id]
for name in self.expert_weight_names
])
def get_rank_expert_workload(self) -> torch.Tensor:
self.moe_load = self.model.get_all_moe_loads()
return self.moe_load
def get_init_expert_map(self, num_moe_layers):
expert_map = self.model.get_all_expert_map(num_moe_layers)
if dist.is_initialized():
world_size = dist.get_world_size()
gathered = torch.empty(
(world_size, *expert_map.shape), # [W, L, E]
dtype=expert_map.dtype,
device=expert_map.device)
dist.all_gather_into_tensor(gathered, expert_map)
all_maps = gathered.permute(1, 0, 2)
all_expert_maps = all_maps.cpu()
for layer_idx in range(num_moe_layers):
self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \
all_expert_maps[layer_idx][self.rank_id]
return all_expert_maps
def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
try:
expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(
expert_map_path)
expert_map_all = self.local2global(expert_map_tensor)
except (TypeError, FileNotFoundError, OSError):
expert_map_all = self.determine_expert_map_all()
for layer_idx in range(num_moe_layers):
if self.model.config.model_type == "qwen3_moe":
self.expert_map_per_layer_cpu[layer_idx] = \
expert_map_all[layer_idx][self.rank_id]
else:
self.expert_map_per_layer_cpu[layer_idx + self.num_dense_layers] = \
expert_map_all[layer_idx][self.rank_id]
return expert_map_all
    def _expert_file_to_tensor(self, expert_map_path: str):
        try:
            with open(expert_map_path, "r") as f:
                data = json.load(f)
            layers_num = data["moe_layer_count"]
            gpus_num = data["layer_list"][0]["device_count"]
            tensor_data = []
            for layer in data["layer_list"]:
                device_data = []
                for device in layer["device_list"]:
                    device_data.append(device["device_expert"])
                tensor_data.append(device_data)
            expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
            return expert_map_tensor, layers_num, gpus_num
        except Exception:
            # log the failure and re-raise so the caller can fall back to the default expert map
            logger.error(f"failed to read expert_map_path: {expert_map_path}")
            raise
def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
if self.rank_id == 0:
num_local_experts = expert_maps.max() + 1
expert_maps_local = self.global2local(expert_maps,
num_local_experts)
expert_maps_list = expert_maps_local.tolist()
record: dict[str, Any] = {
"moe_layer_count": len(expert_maps_list),
"layer_list": []
}
for layer_idx, layer_data in enumerate(expert_maps_list):
layer_record: dict[str, Any] = {
"layer_id": layer_idx,
"device_count": len(layer_data),
"device_list": []
}
for device_idx, experts in enumerate(layer_data):
device_record = {
"device_id": device_idx,
"device_expert": experts
}
layer_record["device_list"].append(device_record)
record["layer_list"].append(layer_record)
with open(expert_map_record_path, "w") as f:
json.dump(record, f, indent=4)
def do_update_expert_map(self, layer_id, updated_expert_map):
self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
buffer_tensor_id):
for expert_tensor, buffer_tensor in zip(
self.expert_param_per_layer[layer_id][local_expert_to_replace],
self.buffer_tensor_list[buffer_tensor_id]):
expert_tensor.copy_(buffer_tensor)
logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
if self.log2phy_map_per_layer[layer_id] is not None:
self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
def global2local(self, placement: torch.Tensor,
E_local: int) -> torch.Tensor:
L, G, _ = placement.shape
device = placement.device
pt_local = torch.full((L, G, E_local),
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement >= 0
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
slot_idx = placement[l_idx, g_idx, k_idx]
pt_local[l_idx, g_idx, slot_idx] = k_idx
return pt_local
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
L, G, E_local = placement_local.shape
device = placement_local.device
max_id = torch.max(placement_local)
E_global = (max_id + 1).item() if max_id >= 0 else 0
if E_global == 0:
return torch.empty((L, G, 0), dtype=torch.long, device=device)
placement_global = torch.full((L, G, E_global),
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement_local >= 0
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
gid_idx = placement_local[l_idx, g_idx, slot_idx]
placement_global[l_idx, g_idx, gid_idx] = slot_idx
return placement_global
def determine_expert_map_all(self):
if self.world_size == 1:
local_ids = torch.arange(self.global_expert_num, dtype=torch.int32)
return local_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1)
local_num_experts = self.global_expert_num // self.world_size
expert_map_all = torch.full(
(self.num_moe_layers, self.world_size, self.global_expert_num),
-1,
dtype=torch.int32)
for r in range(self.world_size):
if r < self.world_size - 1:
start = r * local_num_experts
end = (r + 1) * local_num_experts
local_count = local_num_experts
else:
start = r * local_num_experts
end = self.global_expert_num
local_count = self.global_expert_num - r * local_num_experts
if r < self.init_redundancy_expert:
local_count += 1
if end < self.global_expert_num:
end += 1
else:
start -= 1
local_ids = torch.arange(local_count, dtype=torch.int32)
expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
self.num_moe_layers, -1)
return expert_map_all
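For intuition, the two placement formats used above can be illustrated on a toy configuration. The standalone sketch below mirrors the global2local logic on made-up data (1 layer, 2 ranks, 4 global experts, 2 local slots per rank); it is an illustration, not code from this commit:

import torch

# expert_map layout: [layers, ranks, global_experts]; value = local slot id or -1
expert_map = torch.tensor([[[0, 1, -1, -1],     # rank 0 holds global experts 0 and 1
                            [-1, -1, 0, 1]]])   # rank 1 holds global experts 2 and 3

def toy_global2local(placement: torch.Tensor, e_local: int) -> torch.Tensor:
    L, G, _ = placement.shape
    out = torch.full((L, G, e_local), -1, dtype=torch.long)
    l_idx, g_idx, k_idx = (placement >= 0).nonzero(as_tuple=True)
    out[l_idx, g_idx, placement[l_idx, g_idx, k_idx]] = k_idx
    return out

print(toy_global2local(expert_map, 2))
# tensor([[[0, 1],
#          [2, 3]]])  -> per local slot, the global expert id; local2global inverts this.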

View File

@@ -0,0 +1,134 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from enum import Enum
import torch.distributed as dist
from vllm.logger import logger
class ExpertWeightUpdateState(Enum):
WAITING = 0 # waiting for updated expert_map by EplbWorker
READY = 1 # ready for d2d expert weights updating
TRANSFERRING = 2 # d2d finished and waiting for updating expert_map into model
class D2DExpertWeightLoader:
def __init__(self):
self.comm_op_list = None
self.updated_expert_map = None
self.updated_log2phy_map = None
self.layer_id = -1 # layer id to be updated
self.state = ExpertWeightUpdateState.WAITING
self.recv_expert_list = []
self.mock_flag = True
def set_adator(self, eplb_adaptor):
self.eplb_adaptor = eplb_adaptor
def generate_expert_d2d_transfer_task(self, expert_send_info,
expert_recv_info, updated_expert_map,
layer_id):
        # A new d2d task can not be accepted until the previous send/recv and expert_map/weight update tasks have finished
if self.state != ExpertWeightUpdateState.WAITING:
logger.warning_once(
"current d2d weight update tasks are on-going, cannot accept new weight update task"
)
return
self.updated_expert_map = updated_expert_map
self.layer_id = layer_id
self.comm_op_list = []
for send_info in expert_send_info:
dst_rank, global_expert_id_to_send = send_info
local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
layer_id][global_expert_id_to_send].item()
for src_tensor in self.eplb_adaptor.expert_param_per_layer[
layer_id][local_expert_id]:
src_tensor = src_tensor.clone()
self.comm_op_list.append(
dist.P2POp(dist.isend, src_tensor, dst_rank))
buffer_tensor_id = 0
for recv_info in expert_recv_info:
recv_rank, global_expert_id_to_recv = recv_info
for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[
buffer_tensor_id]:
self.comm_op_list.append(
dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
local_expert_to_replace = self.updated_expert_map[
global_expert_id_to_recv].item()
self.recv_expert_list.append(
(local_expert_to_replace, buffer_tensor_id))
buffer_tensor_id += 1
self.state = ExpertWeightUpdateState.READY
def set_log2phy_map(self, log2phy_map):
self.updated_log2phy_map = log2phy_map
def asyn_expert_weight_transfer(self, reqs):
        # d2d send/recv tasks can only be launched after they have been parsed into self.comm_op_list
if self.state != ExpertWeightUpdateState.READY:
return
# set asynchronous stream for d2d expert weight transfer
if self.comm_op_list:
ret_list = dist.batch_isend_irecv(self.comm_op_list)
reqs.extend(ret_list)
self.state = ExpertWeightUpdateState.TRANSFERRING
def update_expert_map_and_weight(self, reqs):
        # expert_map and weights can only be updated after the send/recv tasks have been launched
if self.state != ExpertWeightUpdateState.TRANSFERRING:
return
        # Wait for the send/recv tasks to finish
for req in reqs:
req.wait()
if self.comm_op_list is not None:
self.comm_op_list = None
# update expert_map
self.eplb_adaptor.do_update_expert_map(self.layer_id,
self.updated_expert_map)
# update log2phy_map
self.eplb_adaptor.do_update_log2phy_map(self.layer_id,
self.updated_log2phy_map)
# update expert weight
buffer_tensor_id = 0
for recv_expert_info in self.recv_expert_list:
local_expert_to_replace, buffer_tensor_id = recv_expert_info
self.eplb_adaptor.do_update_expert_weight(self.layer_id,
local_expert_to_replace,
buffer_tensor_id)
        logger.info(
            f"[EPLB] finished updating expert weights for layer: {self.layer_id}")
self.recv_expert_list = []
self.updated_expert_map = None
self.layer_id = -1
self.state = ExpertWeightUpdateState.WAITING
def load_impl(self, old_expert_table, new_expert_table):
raise NotImplementedError
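D2DExpertWeightLoader is a small state machine (WAITING -> READY -> TRANSFERRING -> WAITING). A hedged sketch of one update cycle as a driver might run it, using only the methods defined above; all arguments are placeholders for values produced elsewhere by the EPLB worker and adaptor:

def run_one_eplb_cycle(eplb_adaptor, send_info, recv_info, new_expert_map,
                       log2phy_map, layer_id):
    # Hypothetical driver; argument values come from the worker/adaptor.
    loader = D2DExpertWeightLoader()
    loader.set_adator(eplb_adaptor)
    loader.set_log2phy_map(log2phy_map)
    reqs: list = []
    # WAITING -> READY: parse send/recv pairs into P2P ops
    loader.generate_expert_d2d_transfer_task(send_info, recv_info,
                                             new_expert_map, layer_id)
    # READY -> TRANSFERRING: launch batch_isend_irecv
    loader.asyn_expert_weight_transfer(reqs)
    # TRANSFERRING -> WAITING: wait on reqs, commit expert_map, log2phy map and weights
    loader.update_expert_map_and_weight(reqs)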

View File

@@ -0,0 +1,189 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# TODO: Remove these EPLB utils once https://github.com/vllm-project/vllm/issues/22246 is merged in vLLM.
import os.path
import random
import sys
import torch
from vllm.logger import logger
def determine_default_expert_map(global_expert_num, world_size, rank_id,
global_redundant_expert_num):
if world_size == 1:
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
return (global_expert_num, local_ids)
local_num_experts = global_expert_num // world_size
expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)
if rank_id < world_size - 1:
start = rank_id * local_num_experts
end = (rank_id + 1) * local_num_experts
local_count = local_num_experts
else:
start = rank_id * local_num_experts
end = global_expert_num
local_count = global_expert_num - rank_id * local_num_experts
if isinstance(local_count, int):
local_ids = torch.arange(local_count, dtype=torch.int32)
expert_map[start:end] = local_ids
return (local_count, expert_map)
def generate_log2phy_map(expert_map):
num_local_experts = expert_map.max() + 1
log2phy_map = expert_map.clone()
num_ranks, num_global_expert = log2phy_map.shape
row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks, \
num_global_expert) * num_local_experts
log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1]
for idx in range(num_global_expert):
positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0]
negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0]
num_rank_holding_expert = positive_rank_idx.size(0)
if num_rank_holding_expert == 0:
log2phy_map[:, idx] = torch.full((num_ranks, ),
0,
dtype=log2phy_map.dtype)
        elif num_rank_holding_expert == 1:
log2phy_map[negative_rank_idx, idx] = torch.full(
(num_ranks - 1, ),
log2phy_map[positive_rank_idx, idx].item(),
dtype=log2phy_map.dtype)
else:
try:
random_list = [
random.choice(log2phy_map[positive_rank_idx, idx])
for _ in range(num_ranks - num_rank_holding_expert)
]
log2phy_map[negative_rank_idx,
idx] = torch.tensor(random_list,
dtype=log2phy_map.dtype)
except Exception as e:
logger.error(f"Fail to get log2phy_map: {str(e)}")
return log2phy_map
def determine_default_log2phy_map(global_expert_num, world_size, rank_id):
if world_size == 1:
local_ids = torch.arange(global_expert_num, dtype=torch.int32)
expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
log2phy_map_all = generate_log2phy_map(expert_map_all)
return log2phy_map_all[rank_id]
local_num_experts = global_expert_num // world_size
expert_map_all = torch.full((world_size, global_expert_num),
-1,
dtype=torch.int32)
for r in range(world_size):
if r < world_size - 1:
start = r * local_num_experts
end = (r + 1) * local_num_experts
local_count = local_num_experts
else:
start = r * local_num_experts
end = global_expert_num
local_count = global_expert_num - r * local_num_experts
if isinstance(local_count, int):
local_ids = torch.arange(local_count, dtype=torch.int32)
expert_map_all[r, start:end] = local_ids
log2phy_map_all = generate_log2phy_map(expert_map_all)
return log2phy_map_all[rank_id]
class EPLBParamUtils:
@staticmethod
def check_iterations(iterations):
        if not isinstance(iterations, int):
            raise TypeError(f"iterations must be an int, got {iterations}.")
        if iterations <= 0:
            raise ValueError(
                f"iterations must be greater than 0, got {iterations}.")
        if iterations > sys.maxsize:
            raise ValueError(
                f"iterations can not be larger than {sys.maxsize}, got {iterations}.")
@staticmethod
def check_dynamic_eplb(dynamic_eplb):
if dynamic_eplb is None:
return
if not isinstance(dynamic_eplb, bool):
raise TypeError("The dynamic_eplb is not bool.")
if dynamic_eplb and os.getenv("DYNAMIC_EPLB", "false") != "true":
raise ValueError(
                'dynamic_eplb can not be enabled unless DYNAMIC_EPLB="true" is exported.'
)
@staticmethod
def check_expert_map_path(expert_map):
if expert_map is None:
return
if not isinstance(expert_map, str):
raise TypeError("The expert_map is not str.")
        if not expert_map.strip():
            raise ValueError("The expert_map is empty.")
        _, ext = os.path.splitext(expert_map)
        if ext.lower() != ".json":
            raise TypeError("The expert_map is not a json file.")
        if not os.path.exists(expert_map):
            raise ValueError("The expert_map does not exist.")
        try:
            with open(expert_map, "r", encoding='utf-8') as f:
                f.read()
        except Exception as e:
            raise IOError(
                f"Failed to read expert info from {expert_map}, please check the read permission of {expert_map}: {e}"
            )
@staticmethod
def check_expert_map_record_path(expert_map_record_path):
if expert_map_record_path is None:
return
if not isinstance(expert_map_record_path, str):
raise TypeError("The expert_map_record_path is not str.")
if not expert_map_record_path.strip():
raise ValueError("The expert_map_record_path is empty.")
_, ext = os.path.splitext(expert_map_record_path)
if ext.lower() != ".json":
raise TypeError("The expert_map_record_path is not json.")
if os.getenv("EXPERT_MAP_RECORD", "false") != "true":
raise ValueError(
                'expert_map_record_path can not be enabled unless EXPERT_MAP_RECORD="true" is exported.'
)
try:
with open(expert_map_record_path, "w", encoding='utf-8') as f:
f.write("")
except Exception as e:
raise IOError(
f"Fail write expert info to {expert_map_record_path}, please check the writing permission of {expert_map_record_path} : {e}"
)
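As a quick illustration of the default split computed above, here is a hedged sketch on toy numbers (assuming the module path used by the worker file below; not taken from a real deployment):

import torch

from vllm_npu.eplb.core.eplb_utils import determine_default_expert_map

# 8 global experts over 2 ranks, no redundancy: each rank owns a contiguous block.
count, expert_map = determine_default_expert_map(8, 2, 0, 0)
print(count)       # 4
print(expert_map)  # tensor([ 0,  1,  2,  3, -1, -1, -1, -1], dtype=torch.int32)
# rank 1 would get tensor([-1, -1, -1, -1,  0,  1,  2,  3]) instead.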

View File

@@ -0,0 +1,440 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from multiprocessing import Process, Queue
from typing import Any
import networkx as nx # type: ignore
import numpy as np
import torch
import torch.distributed as dist
from vllm.logger import logger
from vllm_npu.eplb.core.eplb_utils import generate_log2phy_map
from vllm_npu.eplb.core.policy.policy_factory import (DynamicConfig,
PolicyFactory)
class EplbWorker:
def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
self.policy_type = policy_type
self.policy = PolicyFactory.generate_policy(policy_type,
DynamicConfig())
self.shared_dict = shared_dict
self.old_expert_maps = None
self.enable_d2d = enable_d2d
self.rank_id = dist.get_rank()
def do_update(self):
        # put data into queue
        # in process self.policy.generate_policy()
        # get expert table && tensor
# async stream
# D2D
# H2D
# Get initial expert_map
torch.set_num_threads(1)
if self.old_expert_maps is None:
self.old_expert_maps = self.get_init_expert_maps()
if self.old_expert_maps is not None:
self.num_local_experts = self.old_expert_maps.max() + 1
else:
raise ValueError("Failed to get expert_maps from shared_dict.")
# Get MOE load information
load_info = self.fetch_and_sum_load_info()
if load_info is None:
return
# Get the updated expert table based on the workload information
old_placement = self.global2local(self.old_expert_maps,
self.num_local_experts)
_, _, new_placement = self.calculate_rebalance_experts(
load_info, old_placement)
if not torch.is_tensor(new_placement):
new_placement = torch.tensor(new_placement)
self.check_expert_placement(old_placement, new_placement)
new_expert_maps = self.local2global(new_placement)
self.update_expert_map(new_expert_maps)
if self.policy_type == 2:
update_info = self.compose_expert_update_info_bipartite(
new_expert_maps, self.old_expert_maps)
else:
update_info = self.compose_expert_update_info_greedy(
new_expert_maps, self.old_expert_maps)
self.old_expert_maps = new_expert_maps
logger.info("EPLB Process compute complete")
packed_update_info = self.pack_update_info(update_info)
return packed_update_info
def check_expert_placement(self, old_placement, new_placement):
num_layers = old_placement.shape[0]
num_ranks = old_placement.shape[1]
for layer_id in range(num_layers):
# check if any logical expert is not placed on any rank
if torch.unique(new_placement[layer_id]).numel() < torch.unique(
old_placement[layer_id]).numel():
logger.error(
f"There exists expert not placed on any rank in layer {layer_id}"
)
new_placement[layer_id] = old_placement[layer_id]
continue
for rank_id in range(num_ranks):
new_placement_check = new_placement[layer_id][rank_id]
old_placement_check = old_placement[layer_id][rank_id]
# check if same logical experts are placed on the same NPU
if new_placement_check.numel() != torch.unique(
new_placement_check).numel():
logger.error(
f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
)
new_placement[layer_id] = old_placement[layer_id]
break
# check if there is any experts movement inside one NPU
expert_not_move = torch.isin(new_placement_check,
old_placement_check)
if not torch.equal(new_placement_check[expert_not_move],
old_placement_check[expert_not_move]):
logger.error(
f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
)
new_placement[layer_id] = old_placement[layer_id]
break
def compose_expert_update_info_bipartite(self, updated_expert_maps_org,
current_expert_maps_org):
        # convert the torch tensors to numpy arrays for processing
updated_expert_maps = updated_expert_maps_org.clone()
current_expert_maps = current_expert_maps_org.clone()
updated_expert_maps = np.array(updated_expert_maps)
current_expert_maps = np.array(current_expert_maps)
num_layers = current_expert_maps.shape[0]
for layer_id in range(num_layers):
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
current_expert_maps_this_layer = current_expert_maps[layer_id]
updated_expert_maps_this_layer_org = updated_expert_maps_org[
layer_id]
expert_send_info_this_layer: dict[Any, Any] = {}
expert_recv_info_this_layer: dict[Any, Any] = {}
# Guard Clause: if there is no expert weight update, avoid subsequent processing
            if (np.equal(updated_expert_maps_this_layer,
                         current_expert_maps_this_layer)).all():
                yield (expert_send_info_this_layer,
                       expert_recv_info_this_layer,
                       updated_expert_maps_this_layer_org, layer_id)
                continue
# Parse expert_ids each rank needs to receive from other ranks
dst_rank_indices, experts_to_recv = np.where(
(current_expert_maps_this_layer == -1)
& (updated_expert_maps_this_layer != -1))
# record src ranks for potential transfer
src_ranks_set = dict()
for idx in range(len(dst_rank_indices)):
expert_id = experts_to_recv[idx].item()
if expert_id not in src_ranks_set:
src_ranks_set[expert_id] = np.where(
current_expert_maps_this_layer[:, expert_id] != -1)[0]
# loop until all experts are scheduled
while len(dst_rank_indices) > 0:
# construct bipartite graph
graph_expert_update: nx.Graph = nx.Graph()
for idx in range(len(dst_rank_indices)):
dst_rank_id = dst_rank_indices[idx].item()
expert_id = experts_to_recv[idx].item()
# add src ranks
src_rank_ids = src_ranks_set[expert_id]
graph_expert_update.add_nodes_from(src_rank_ids,
bipartite=0)
# add dest rank
graph_expert_update.add_node(str(dst_rank_id), bipartite=1)
# add edges
for src_rank_id in src_rank_ids:
graph_expert_update.add_edge(src_rank_id,
str(dst_rank_id))
# graph may not be connected
connected_components = list(
nx.connected_components(graph_expert_update))
all_matches = {}
# matching in this loop
for i, component in enumerate(connected_components):
subgraph = graph_expert_update.subgraph(component)
component_matching = nx.bipartite.maximum_matching(
subgraph)
all_matches.update(component_matching)
for src_rank, dst_rank in all_matches.items():
dst_rank = int(dst_rank)
assert src_rank != dst_rank
if graph_expert_update.nodes[src_rank]['bipartite'] == 0:
# currently not scheduled experts in rank dst_rank
experts_v = experts_to_recv[np.where(
dst_rank_indices == dst_rank)]
# src: src_rank, dest: dst_rank, expert: expert_id
expert_id = np.intersect1d(
experts_v,
np.where(current_expert_maps_this_layer[src_rank]
!= -1))[0]
# record send/rcv pairs
if src_rank not in expert_send_info_this_layer:
expert_send_info_this_layer[src_rank] = []
if dst_rank not in expert_recv_info_this_layer:
expert_recv_info_this_layer[dst_rank] = []
expert_send_info_this_layer[src_rank].append(
(dst_rank, expert_id))
expert_recv_info_this_layer[dst_rank].append(
(src_rank, expert_id))
remove_index = np.where(
np.logical_and(dst_rank_indices == dst_rank,
experts_to_recv == expert_id))
# update
dst_rank_indices = np.delete(dst_rank_indices,
remove_index)
experts_to_recv = np.delete(experts_to_recv,
remove_index)
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
updated_expert_maps_this_layer_org, layer_id)
# TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
def compose_expert_update_info_greedy(self, updated_expert_maps,
current_expert_maps):
num_layers = current_expert_maps.shape[0]
for layer_id in range(num_layers):
updated_expert_maps_this_layer = updated_expert_maps[layer_id]
current_expert_maps_this_layer = current_expert_maps[layer_id]
expert_send_info_this_layer: dict[Any, Any] = {}
expert_recv_info_this_layer: dict[Any, Any] = {}
# Guard Clause: if there is no expert weight update, avoid subsequent processing
            if torch.equal(updated_expert_maps_this_layer,
                           current_expert_maps_this_layer):
                yield (expert_send_info_this_layer,
                       expert_recv_info_this_layer,
                       updated_expert_maps_this_layer, layer_id)
                continue
# Parse expert_ids each rank needs to receive from other ranks
dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \
& (updated_expert_maps_this_layer != -1))
# Parse expert_ids each rank needs to send to other ranks
src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \
& (updated_expert_maps_this_layer == -1))
for idx in range(len(dst_rank_indices)):
dst_rank_id = dst_rank_indices[idx].item()
expert_id = experts_to_recv[idx].item()
if dst_rank_id not in expert_recv_info_this_layer:
expert_recv_info_this_layer[dst_rank_id] = []
if not torch.isin(torch.tensor(expert_id),
experts_to_send).any():
# if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
candidate_src_rank_indices = torch.where(
current_expert_maps_this_layer[:, expert_id] != -1)[0]
else:
candidate_src_rank_indices = src_rank_indices[
experts_to_send == expert_id]
                # TODO: improve the selection criterion for the NPU that sends expert_id, e.g. considering intra-node vs. inter-node transfer
src_rank_id = candidate_src_rank_indices[0].item()
if src_rank_id not in expert_send_info_this_layer:
expert_send_info_this_layer[src_rank_id] = []
expert_send_info_this_layer[src_rank_id].append(
(dst_rank_id, expert_id))
expert_recv_info_this_layer[dst_rank_id].append(
(src_rank_id, expert_id))
yield (expert_send_info_this_layer, expert_recv_info_this_layer,
updated_expert_maps_this_layer, layer_id)
def calculate_rebalance_experts(self, load_info, old_placement):
"""
Compute `new_map` by calling the `rebalance_experts` method of the policy instance.
"""
if self.old_expert_maps is None:
return False, None, None
changed, priority, new_map = self.policy.rebalance_experts(
old_placement, load_info)
return changed, priority, new_map
def get_init_expert_maps(self):
"""
Read the initial expert_map from shared_dict.
"""
return self.shared_dict.get("expert_maps", None)
def fetch_and_sum_load_info(self):
"""
Each time the subprocess is awakened, read the latest moe_load
(shape: [num_moe_layers, num_experts_per_layer]) from shared_dict.
"""
return self.shared_dict.get("moe_load", None)
def update_expert_map(self, expert_maps):
self.shared_dict["expert_maps"] = expert_maps
def global2local(self, placement: torch.Tensor,
E_local: int) -> tuple[torch.Tensor, torch.Tensor]:
L, G, _ = placement.shape
device = placement.device
pt_local = torch.full((L, G, E_local),
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement >= 0
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
slot_idx = placement[l_idx, g_idx, k_idx]
pt_local[l_idx, g_idx, slot_idx] = k_idx
return pt_local
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
L, G, E_local = placement_local.shape
device = placement_local.device
max_id = torch.max(placement_local)
E_global = (max_id + 1).item() if max_id >= 0 else 0
if E_global == 0:
return torch.empty((L, G, 0), dtype=torch.long, device=device)
placement_global = torch.full((L, G, E_global),
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement_local >= 0
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
gid_idx = placement_local[l_idx, g_idx, slot_idx]
placement_global[l_idx, g_idx, gid_idx] = slot_idx
return placement_global
def pack_update_info(self, update_info_generator):
"""
Pack a list of update info tuples for efficient IPC.
"""
send_all = []
recv_all = []
maps = []
log2phy_all = []
layer_ids = []
for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
send_info_this_rank = send_info[
self.rank_id] if self.rank_id in send_info else []
recv_info_this_rank = recv_info[
self.rank_id] if self.rank_id in recv_info else []
send_all.append(send_info_this_rank)
recv_all.append(recv_info_this_rank)
maps.append(new_expert_map[self.rank_id].numpy().tolist())
log2phy_map = generate_log2phy_map(new_expert_map)
log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist())
layer_ids.append(layer_id)
return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))
class EplbProcess:
def __init__(self,
shared_dict,
policy_type: int = 0,
enable_d2d: bool = True):
"""
Args:
shared_dict: Cross-process shared dict returned by Manager().dict()
policy_type: Integer passed to PolicyFactory.generate_policy
enable_d2d: Whether to enable D2D loading
"""
self.shared_dict = shared_dict
self.policy_type = policy_type
self.enable_d2d = enable_d2d
self.planner_q: Queue[Any] = Queue()
self.block_update_q: Queue[Any] = Queue(maxsize=1)
# Create EplbWorker instance
self.worker = EplbWorker(self.shared_dict, self.policy_type,
self.enable_d2d)
def worker_process(self, planner_q, block_update_q):
"""
Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete.
"""
while True:
try:
planner_q.get()
packed_update_info = self.worker.do_update()
while True:
if not block_update_q.empty():
continue
block_update_q.put(packed_update_info)
break
except Exception as e:
logger.warning(f"[EPLB subprocess Exiting due to error: {e}",
exc_info=True)
break
def _launch_process(self):
"""
Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
"""
proc = Process(target=self.worker_process,
args=(self.planner_q, self.block_update_q),
daemon=True)
proc.start()
return proc
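A hedged sketch of how the main process can drive EplbProcess for a single step; it assumes an already-initialized torch.distributed process group (EplbWorker calls dist.get_rank()), and the shared-dict contents are placeholders for the tensors produced by the adaptor:

from multiprocessing import Manager

def drive_one_eplb_step(expert_maps, moe_load):
    # expert_maps: [layers, world_size, global_experts]; moe_load: [layers, experts]
    shared = Manager().dict()
    shared["expert_maps"] = expert_maps
    shared["moe_load"] = moe_load
    eplb = EplbProcess(shared, policy_type=0, enable_d2d=True)
    eplb._launch_process()            # spawn the worker subprocess
    eplb.planner_q.put(1)             # wake the worker loop once
    return eplb.block_update_q.get()  # per-layer (send, recv, map, log2phy, layer_id)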

View File

@@ -0,0 +1,42 @@
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# TODO: Remove this policy once https://github.com/vllm-project/vllm/pull/24069 is merged in vLLM.
from abc import abstractmethod
class DynamicConfig:
placement_policy = None
max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host
ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed
num_die_per_host = 8 # Number of dies on each host machine
class EplbPolicy:
def __init__(self, config: DynamicConfig):
self.config = config
@abstractmethod
def rebalance_experts(self, current_expert_table, expert_workload):
"""
Pass in the weights and return expert replication and placement under relevant constraints.
INPUT:
current_expert_table: [layerId, rankId, expert_num_i]
expert_workload = expert_table[layer0][rankId][expert_num_i]
RETURNED: (res, expert_table)
res:
1 -- table_changed
0 -- not_changed
expert_table: [layerId, rankId, expert_num_i]
expert_num_i --- [0, MaxExpertPerRank]
expertID = expert_table[layer0][rankId][expert_num_i]
array_values:
[0, 1, 2, 3, 248]
[4, 5, 6, 7, 254]
[8, 9, 10, 11, 71]
...
[252, 253, 254, 255, 0]
"""
pass
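To make the contract concrete, here is a minimal hypothetical policy that satisfies it; like the dynamic policies later in this commit it returns a (changed, per-layer priority, placement) triple and simply reports "not changed":

class NoOpPolicy(EplbPolicy):
    # Hypothetical example policy: keep the current placement untouched.

    def rebalance_experts(self, current_expert_table, expert_workload):
        # res = 0 (table not changed), empty priority list, placement returned as-is
        return 0, [], current_expert_table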

View File

@@ -0,0 +1,389 @@
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# TODO: Remove this policy once https://github.com/vllm-project/vllm/pull/24069 is merged in vLLM.
from collections import defaultdict
from typing import cast
import numpy as np
from .policy_abstract import DynamicConfig, EplbPolicy
class DynamicTable:
# workload_table:
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position
# Size: number of layers * number of GPUs * number of experts per GPU per layer
# The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer
# For experts that are not available or collected, the value is set to -1
workload_table = None
# placement_table:
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position
# Size: number of layers * number of GPUs * number of experts per GPU per layer
# The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer
# For experts that are not available or collected, the value is set to -1
placement_table = None
class DynamicEplb(EplbPolicy):
def __init__(self, config: DynamicConfig):
super().__init__(config)
@staticmethod
def add_redundant(current_expert_table, expert_workload,
num_original_expert):
layer_num, npu_num, experts_per_npu = expert_workload.shape
workload_new = np.zeros((layer_num, num_original_expert))
for layer_idx in range(layer_num):
workload_dict: dict[int, int] = defaultdict(int)
placement_layer = current_expert_table[layer_idx].copy()
workload_layer = expert_workload[layer_idx].copy()
for npu_idx in range(npu_num):
for expert_idx in range(experts_per_npu):
workload_dict[placement_layer[npu_idx][
expert_idx]] += workload_layer[npu_idx][expert_idx]
for expert_idx in range(num_original_expert):
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
return workload_new
@staticmethod
# Split hot (high-load) experts into redundant experts
def original_compute_balanced_pack_redundancy(origin_weights, card_num,
num_redundancy_expert):
# Step 1: Sort the items by weight in descending order (we are sorting by weight now)
# Sort based on the second element (the second value of each tuple)
route_expert_num = len(origin_weights)
route_expert_redundancy: list[list[int]] = [
[] for _ in range(route_expert_num)
]
for i in range(num_redundancy_expert):
sorted_indices = np.argsort([t[1] for t in origin_weights],
kind='stable')[::-1]
weights = [origin_weights[idx] for idx in sorted_indices]
tmp_raw_weight = weights[0][1] * (
len(route_expert_redundancy[weights[0][0]]) + 1)
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
avg_weight = tmp_raw_weight / (
len(route_expert_redundancy[weights[0][0]]) + 1)
weights[0] = (weights[0][0], avg_weight)
origin_weights = weights
# Step 2: Calculate the number of items per box
expert_num = route_expert_num + num_redundancy_expert
items_per_box = expert_num // card_num # Number of items per box
        remaining_items = expert_num % card_num  # Number of leftover items after even division
# Step 3: Initialize card_num boxes with empty lists to store item IDs
boxes: list[list[int]] = [[] for _ in range(card_num)]
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
box_weights = [0] * card_num # To store the total weight of each box
box_counts = [0] * card_num # To store the number of items in each box
index = 0
for i in range(route_expert_num):
redundancy_num = len(route_expert_redundancy[i])
for _ in range(redundancy_num):
cur_weight = 0
for item, weight in origin_weights:
if item == i:
cur_weight = weight
boxes[index].append(i)
boxes_weights[index].append(cur_weight)
box_weights[index] += cur_weight
box_counts[index] += 1
index += 1
sorted_indices = np.argsort([t[1] for t in origin_weights],
kind='stable')[::-1]
origin_weights = [origin_weights[idx] for idx in sorted_indices]
# Step 4: Distribute items into boxes based on weight
for item_id, weight in origin_weights:
            # Find the least-loaded box that still has space and does not already hold this item
min_box_index = -1
for i in range(card_num):
if item_id in boxes[i]:
continue
# Only choose boxes that still have space (box_counts[i] < items_per_box)
if box_counts[i] < items_per_box or (box_counts[i]
== items_per_box
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
min_box_index = i
# Place the item (id) into the selected box
boxes[min_box_index].append(item_id)
boxes_weights[min_box_index].append(weight)
box_weights[min_box_index] += weight
box_counts[min_box_index] += 1
# If there's an imbalance in the remaining items, reduce the "remaining_items" counter
if box_counts[min_box_index] == (items_per_box +
1) and remaining_items > 0:
remaining_items -= 1
# Step 5: Output each box's contents and total weight
result = []
for i in range(card_num):
result.append({
"box_index": i + 1,
"items": boxes[i], # List of item IDs in the box
"weight": boxes_weights[i],
"total_weight": box_weights[i], # Total weight in this box
"item_count": box_counts[i] # Number of items in the box
})
return result, boxes
# Split hot (high-load) experts into redundant experts
@staticmethod
def compute_balanced_pack_redundancy(origin_weights, card_num,
num_redundancy_expert):
route_expert_num = len(origin_weights)
route_expert_redundancy: list[list[int]] = [
[] for _ in range(route_expert_num)
]
for i in range(num_redundancy_expert):
sorted_indices = np.argsort([t[1] for t in origin_weights],
kind='stable')[::-1]
weights = [origin_weights[idx] for idx in sorted_indices]
tmp_raw_weight = weights[0][1] * (
len(route_expert_redundancy[weights[0][0]]) + 1)
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
avg_weight = tmp_raw_weight / (
len(route_expert_redundancy[weights[0][0]]) + 1)
weights[0] = (weights[0][0], avg_weight)
origin_weights = weights
expert_num = route_expert_num + num_redundancy_expert
if card_num == 0:
raise RuntimeError("card_num can not be 0.")
items_per_box = expert_num // card_num
remaining_items = expert_num % card_num
boxes: list[list[int]] = [[] for _ in range(card_num)]
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
box_weights = [0] * card_num
box_counts = [0] * card_num
all_weights = np.zeros((expert_num, ), dtype='object')
all_weights[:route_expert_num] = origin_weights
index = route_expert_num
for i in range(route_expert_num):
redundancy_num = len(route_expert_redundancy[i])
for _ in range(redundancy_num):
for item, weight in origin_weights:
if item == i:
all_weights[index] = (item, weight)
index += 1
sorted_indices = np.argsort([t[1] for t in all_weights],
kind='stable')[::-1]
all_weights = [all_weights[idx] for idx in sorted_indices]
for item_id, weight in all_weights:
min_box_index = -1
for i in range(card_num):
if box_counts[i] < items_per_box or (box_counts[i]
== items_per_box
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
if item_id not in boxes[i]:
min_box_index = i
boxes[min_box_index].append(item_id)
boxes_weights[min_box_index].append(weight)
box_weights[min_box_index] += weight
box_counts[min_box_index] += 1
if box_counts[min_box_index] == (items_per_box +
1) and remaining_items > 0:
remaining_items -= 1
result = []
for i in range(card_num):
result.append({
"box_index": i + 1,
"items": boxes[i],
"weight": boxes_weights[i],
"total_weight": box_weights[i],
"item_count": box_counts[i]
})
return result, boxes
# Scheme without redundant experts
@staticmethod
def compute_balanced_pack(origin_weights, card_num):
sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1]
weights = origin_weights[sorted_indices]
expert_num = len(weights)
if card_num == 0:
raise RuntimeError("card_num can not be 0.")
items_per_box = expert_num // card_num
remaining_items = expert_num % card_num
boxes: list[list[int]] = [[] for _ in range(card_num)]
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
box_weights = [0] * card_num
box_counts = [0] * card_num
for item_id, weight in weights:
min_box_index = -1
for i in range(card_num):
if box_counts[i] < items_per_box or (box_counts[i]
== items_per_box
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
min_box_index = i
boxes[min_box_index].append(item_id)
boxes_weights[min_box_index].append(weight)
box_weights[min_box_index] += weight
box_counts[min_box_index] += 1
if box_counts[min_box_index] == (items_per_box +
1) and remaining_items > 0:
remaining_items -= 1
result = []
for i in range(card_num):
result.append({
"box_index": i + 1,
"items": boxes[i],
"weight": boxes_weights[i],
"total_weight": box_weights[i],
"item_count": box_counts[i]
})
return result, boxes
@staticmethod
def get_redundant_num(npu_num, counts):
        redundant_num_each_npu: int = int(np.sum(counts - 1))
return redundant_num_each_npu
@staticmethod
def calculate_max_heat_per_layer(workload_table, layer_num):
max_heat_per_layer: list[float] = []
for layer_idx in range(layer_num):
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
max_heat_per_layer.append(np.max(npu_heats_now))
return max_heat_per_layer
@staticmethod
def constraint_expert_local_exchange(current_expert_table,
global_deployment):
for layer_id in range(len(global_deployment)):
for card_id in range(len(global_deployment[layer_id])):
current_list = [
int(x) for x in current_expert_table[layer_id][card_id]
]
new_list = [
int(x) for x in global_deployment[layer_id][card_id]
]
num = len(new_list)
new_index = [-1] * num
new_result = [-1] * num
remaining_elements = []
for i in range(num):
flag = True
for j in range(num):
if new_list[i] == current_list[j] and new_index[
j] == -1:
new_index[j] = 0
new_result[j] = current_list[j]
flag = False
break
if flag:
remaining_elements.append(new_list[i])
index = 0
for k in range(num):
if new_result[k] == -1:
new_result[k] = remaining_elements[index]
index += 1
global_deployment[layer_id][card_id] = new_result
return global_deployment
def rebalance_experts(self, current_expert_table, expert_workload):
info = DynamicTable()
info.workload_table = np.array(expert_workload)
info.placement_table = np.array(current_expert_table)
assert info.workload_table is not None
layer_num, num_npus, experts_per_npu = info.workload_table.shape
assert info.placement_table is not None
row = cast(np.ndarray, info.placement_table[0])
expert_ids, counts = np.unique(row, return_counts=True)
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
num_original_expert = len(expert_ids)
layer_workloads = self.add_redundant(info.placement_table,
info.workload_table,
num_original_expert)
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
info.workload_table, layer_num)
npu_heat_all_origin = sum(max_heat_per_layer_before)
# Perform load balancing and deploy redundant experts
layer_num = layer_workloads.shape[0]
expert_num = layer_workloads.shape[1]
        # Validate the expert count, check that the number of NPUs is positive, and ensure the number of redundant experts does not exceed the number of NPUs
if num_original_expert != expert_num:
raise ValueError(
f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
)
if num_npus <= 0:
raise ValueError("the number of NPUs must be greater than 0")
if num_npus < num_redundancy_expert:
raise ValueError(
f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}"
)
# Number of experts deployed on each card includes one redundant expert
global_deployment: list[list[list[int]]] = [[[]
for _ in range(num_npus)]
for _ in range(layer_num)]
# Iterate to obtain the placement strategy for each layer, taking computational balance into account
max_heat_per_layer_after = np.zeros([layer_num])
for layer in range(layer_num):
# Get the expert IDs and their corresponding workloads for the current layer;
# workloads need to be normalized, and one redundant expert is added per card
weights = np.zeros((expert_num, ), dtype='object')
for expert_id, workload_weight in enumerate(
layer_workloads[layer]):
weights[expert_id] = (expert_id, workload_weight)
# Obtain the globally balanced placement strategy for each layer
result, layer_deployment = self.original_compute_balanced_pack_redundancy(
weights, num_npus, num_redundancy_expert)
global_deployment[layer] = layer_deployment
max_heat_per_layer_after[layer] = max(
result, key=lambda x: x['total_weight'])['total_weight']
new_global_deployment = self.constraint_expert_local_exchange(
current_expert_table, global_deployment)
# Obtain the priority of each layer
layer_changed_ratio = []
for layer_idx in range(layer_num):
layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] /
max_heat_per_layer_before[layer_idx])
per_layer_priority = np.argsort(layer_changed_ratio)
npu_heat_all_after = sum(max_heat_per_layer_after)
change = 0
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
change = 1
return change, per_layer_priority, np.array(
new_global_deployment).tolist()
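An illustrative call to the policy above on a toy table (1 layer, 2 NPUs, 2 expert slots each, no redundant experts); shapes follow the DynamicTable comments and the workload numbers are made up:

policy = DynamicEplb(DynamicConfig())

placement = [[[0, 1], [2, 3]]]    # [layer, npu, slot] -> physical expert id
workload = [[[10, 30], [20, 5]]]  # same shape -> per-slot heat

changed, layer_priority, new_placement = policy.rebalance_experts(
    placement, workload)
# changed is 1 only if the rebalanced per-layer max heat drops below 95% of the
# original; new_placement keeps the [layer, npu, slot] layout.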

View File

@@ -0,0 +1,771 @@
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# TODO: Remove this policy once https://github.com/vllm-project/vllm/pull/24069 is merged in vLLM.
from abc import abstractmethod
from collections import defaultdict
import numpy as np
class DynamicConfig:
placement_policy = None
max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host
ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed
num_die_per_host = 8 # Number of dies on each host machine
class EplbPolicy:
def __init__(self, config: DynamicConfig):
self.config = config
@abstractmethod
def rebalance_experts(self, current_expert_table, expert_workload):
"""
Pass in the weights and return expert replication and placement under relevant constraints.
INPUT:
current_expert_table: [layerId, rankId, expert_num_i]
expert_workload = expert_table[layer0][rankId][expert_num_i]
RETURNED: (res, expert_table)
res:
1 -- table_changed
0 -- not_changed
expert_table: [layerId, rankId, expert_num_i]
expert_num_i --- [0, MaxExpertPerRank]
expertID = expert_table[layer0][rankId][expert_num_i]
array_values:
[0, 1, 2, 3, 248]
[4, 5, 6, 7, 254]
[8, 9, 10, 11, 71]
...
[252, 253, 254, 255, 0]
"""
pass
class DynamicTable:
# workload_table:
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position
# Size: number of layers * number of GPUs * number of experts per GPU per layer
# The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer
# For experts that are not available or collected, the value is set to -1
workload_table = None
# placement_table:
# 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position
# Size: number of layers * number of GPUs * number of experts per GPU per layer
# The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer
# For experts that are not available or collected, the value is set to -1
placement_table = None
class DynamicEplbV2(EplbPolicy):
def __init__(self, config: DynamicConfig):
super().__init__(config)
@staticmethod
def safe_divide(a, b):
if b == 0:
print("Division by zero is not allowed")
return 0
return a / b
@staticmethod
def safe_exact_divide(a, b):
if b == 0:
print("Division by zero is not allowed")
return 0
return a // b
@staticmethod
def safe_mod(a, b):
if b == 0:
print("Division by zero is not allowed")
return 0
return a % b
@staticmethod
def add_redundant(current_expert_table, expert_workload,
num_original_expert):
layer_num, npu_num, experts_per_npu = expert_workload.shape
workload_new = np.zeros((layer_num, num_original_expert))
for layer_idx in range(layer_num):
workload_dict: dict[int, int] = defaultdict(int)
placement_layer = current_expert_table[layer_idx].copy()
workload_layer = expert_workload[layer_idx].copy()
for npu_idx in range(npu_num):
for expert_idx in range(experts_per_npu):
workload_dict[placement_layer[npu_idx][
expert_idx]] += workload_layer[npu_idx][expert_idx]
for expert_idx in range(num_original_expert):
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
return workload_new
@staticmethod
def get_redundant_num(npu_num, counts):
redundant_num_each_npu: int = int(np.sum(counts - 1))
return redundant_num_each_npu
@staticmethod
def calculate_max_heat_per_layer(workload_table, layer_num):
max_heat_per_layer: list[float] = []
for layer_idx in range(layer_num):
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
max_heat_per_layer.append(np.max(npu_heats_now))
return max_heat_per_layer
def calculate_initial_imbalance(self, global_deployment,
new_layer_workloads):
device_num = global_deployment.shape[1]
layer_imbalance = []
expert_num = np.zeros_like(new_layer_workloads)
for layer_id, layer in enumerate(global_deployment):
for device in layer:
for expert_id in device:
expert_num[layer_id][expert_id] += 1
for layer_id, layer in enumerate(global_deployment):
cur_layer_max_workload = 0
total_workload = 0
for box in layer:
box_workload = 0
for expert_id in box:
update_workload = self.safe_divide(
new_layer_workloads[layer_id][expert_id],
expert_num[layer_id][expert_id])
box_workload += update_workload
total_workload += update_workload
if cur_layer_max_workload < box_workload:
cur_layer_max_workload = box_workload
cur_layer_imbalance = self.safe_divide(
cur_layer_max_workload,
(self.safe_divide(total_workload, device_num)))
layer_imbalance.append(cur_layer_imbalance)
return layer_imbalance
def compute_redundant_assignments(self, base_experts,
num_redundant_experts, num_experts):
redundant_assignments: list[list[int]] = [[]
for _ in range(num_experts)]
current_weights = base_experts.copy()
for i in range(num_redundant_experts):
sorted_indices = np.argsort([w for _, w in current_weights],
kind='stable')[::-1]
sorted_weights = [current_weights[i] for i in sorted_indices]
target_expert = sorted_weights[0]
expert_id, original_weight = target_expert
current_redundancy = len(redundant_assignments[expert_id])
new_avg_weight = self.safe_divide(
original_weight * (current_redundancy + 1),
(current_redundancy + 2))
redundant_assignments[expert_id].append(num_experts + i)
current_weights[sorted_indices[0]] = (expert_id, new_avg_weight)
sorted_indices = np.argsort([w for _, w in current_weights],
kind='stable')[::-1]
sorted_weights = [current_weights[i] for i in sorted_indices]
return redundant_assignments, sorted_weights
def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos,
num_experts, num_exist_expert,
device_assignments, device_counts,
expert_from_device,
com_between_devices):
current_weights = np.zeros((num_experts, ), dtype='object')
for expert_id, workload_weight in enumerate(layer_workloads):
current_weights[expert_id] = (expert_id, workload_weight)
devices_with_slots = []
for device_id, device_rendun_pos in enumerate(rendun_pos):
if len(device_rendun_pos) != 0:
devices_with_slots.append(device_id)
while devices_with_slots:
sorted_indices = np.argsort([w for _, w in current_weights],
kind='stable')[::-1]
sorted_weights = [current_weights[i] for i in sorted_indices]
for index, target_weight in enumerate(sorted_weights):
expert_id, original_weight = target_weight
if original_weight == -1:
print("Error:Redundant expert failure re-occurred")
redundancy_successful = True
break
redundancy_successful = False
for cur_device_id in devices_with_slots:
if expert_id not in device_assignments[cur_device_id]:
pos = rendun_pos[cur_device_id].pop()
if len(rendun_pos[cur_device_id]) == 0:
devices_with_slots = [
device_id for device_id in devices_with_slots
if device_id != cur_device_id
]
device_assignments[cur_device_id][pos] = expert_id
device_counts[cur_device_id] += 1
communication_box_index = expert_from_device[expert_id]
com_between_devices[cur_device_id][
communication_box_index] = expert_id
new_weight = self.safe_divide(
(original_weight * num_exist_expert[expert_id]),
(num_exist_expert[expert_id] + 1))
sorted_weights[index] = (expert_id, new_weight)
num_exist_expert[expert_id] += 1
redundancy_successful = True
break
if redundancy_successful:
break
sorted_indices = np.argsort([id for id, _ in sorted_weights],
kind='stable')
sorted_weights = [sorted_weights[i][1] for i in sorted_indices]
return sorted_weights, device_assignments, device_counts, com_between_devices
@staticmethod
def prepare_expert_list(base_experts, redundant_assignments,
num_redundant_experts):
redundant_expert_list = np.empty(num_redundant_experts, dtype=object)
index = 0
num_experts = len(redundant_assignments)
for expert_id in range(num_experts):
for _ in redundant_assignments[expert_id]:
redundant_expert_list[index] = (expert_id,
next(w
for eid, w in base_experts
if eid == expert_id))
index += 1
sorted_indices = np.argsort([w for _, w in redundant_expert_list],
kind='stable')[::-1]
return [redundant_expert_list[i] for i in sorted_indices]
@staticmethod
def non_redundant_expert_information(origin_deployment, updated_weights,
rendun_pos):
device_num = len(origin_deployment)
num_experts_per_device = origin_deployment.shape[1]
device_assignments = [[-1 for _ in range(num_experts_per_device)]
for _ in range(device_num)]
device_weights = [[0 for _ in range(num_experts_per_device)]
for _ in range(device_num)]
device_loads = [0] * device_num
device_counts = [0] * device_num
for device_id, device in enumerate(origin_deployment):
for index, expert_id in enumerate(device):
if index in rendun_pos[device_id]:
continue
device_assignments[device_id][index] = expert_id
cur_weight = next(
weight for expert_id_of_weight, weight in updated_weights
if expert_id_of_weight == expert_id)
device_weights[device_id][index] = cur_weight
device_loads[device_id] += cur_weight
device_counts[device_id] += 1
return device_assignments, device_weights, device_loads, device_counts
def recomputing_initial_weight(self, layer_workloads, device_assignments):
num_all_experts = [0] * len(layer_workloads)
for device in device_assignments:
for expert_id in device:
if expert_id != -1:
num_all_experts[expert_id] += 1
cur_layer_workload = []
for expert_id, weight in enumerate(layer_workloads):
if num_all_experts[expert_id] == 0:
cur_layer_workload.append(-1)
else:
cur_layer_workload.append(
self.safe_divide(weight, num_all_experts[expert_id]))
return cur_layer_workload, num_all_experts
def distribute_redun_experts(self, layer_workloads, device_assignments,
device_weights, device_loads, device_counts,
redundant_expert_list, expert_from_device,
num_experts, rendun_pos):
num_devices = len(device_assignments)
com_between_devices: list[dict[int,
int]] = [{} for _ in range(num_devices)]
for expert_id, weight in redundant_expert_list:
candidate = -1
for dev_id in range(num_devices):
if len(rendun_pos[dev_id]) == 0:
continue
if expert_id in device_assignments[dev_id]:
continue
if candidate == -1 or device_loads[dev_id] < device_loads[
candidate]:
candidate = dev_id
if candidate != -1:
pos = rendun_pos[candidate].pop()
device_assignments[candidate][pos] = expert_id
device_weights[candidate][pos] = weight
device_loads[candidate] += weight
device_counts[candidate] += 1
communication_box_index = expert_from_device[expert_id]
com_between_devices[candidate][
communication_box_index] = expert_id
if any(sublist for sublist in rendun_pos):
cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(
layer_workloads, device_assignments)
update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments(
cur_layer_workload, rendun_pos, num_experts, num_exist_expert,
device_assignments, device_loads, expert_from_device,
com_between_devices)
device_loads = [0] * len(device_counts)
for device_id, device in enumerate(device_assignments):
for index, expert_id in enumerate(device):
device_weights[device_id][index] = update_workload[
expert_id]
device_loads[device_id] += update_workload[expert_id]
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
def redundancy_again(self, layer_workloads, origin_weights,
origin_deployment, expert_from_device, num_node,
is_node_redundant, rendun_pos):
num_experts = len(origin_weights)
if is_node_redundant:
num_experts = num_experts * num_node
num_redundant_experts = 0
for rank_empty_pos in rendun_pos:
num_redundant_experts += len(rank_empty_pos)
redundant_assignments, updated_weights = self.compute_redundant_assignments(
origin_weights, num_redundant_experts, num_experts)
redundant_expert_list = self.prepare_expert_list(
updated_weights, redundant_assignments, num_redundant_experts)
device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information(
origin_deployment, updated_weights, rendun_pos)
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts(
layer_workloads, device_assignments, device_weights, device_loads,
device_counts, redundant_expert_list, expert_from_device,
num_experts, rendun_pos)
return device_assignments, device_weights, device_loads, device_counts, com_between_devices
@staticmethod
def generate_allocation_report(device_assignments, device_weights,
device_loads, device_counts):
report = []
max_load = 0.0
for dev_id in range(len(device_assignments)):
current_load = device_loads[dev_id]
max_load = max(max_load, current_load)
report.append({
"device_id": dev_id + 1,
"assigned_experts": device_assignments[dev_id],
"expert_weights": device_weights[dev_id],
"total_load": current_load,
"expert_count": device_counts[dev_id]
})
return report, max_load
@staticmethod
def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id,
next_device_id, cur_layer_result, com_between_devices):
cur_device_deployment = cur_layer_result[cur_device_id][
'assigned_experts']
next_device_deployment = cur_layer_result[next_device_id][
'assigned_experts']
cur_device_weight = cur_layer_result[cur_device_id]['expert_weights']
next_device_weight = cur_layer_result[next_device_id]['expert_weights']
cur_expert_id = cur_device_deployment[cur_exchange_index]
next_expert_id = next_device_deployment[next_exchange_index]
cur_device_deployment[cur_exchange_index] = next_expert_id
next_device_deployment[next_exchange_index] = cur_expert_id
cur_expert_weight = cur_device_weight[cur_exchange_index]
next_expert_weight = next_device_weight[next_exchange_index]
cur_device_weight[cur_exchange_index] = next_expert_weight
next_device_weight[next_exchange_index] = cur_expert_weight
cur_layer_result[cur_device_id][
'total_load'] += next_expert_weight - cur_expert_weight
cur_layer_result[next_device_id][
'total_load'] += cur_expert_weight - next_expert_weight
com_between_devices[cur_device_id][next_device_id] = next_expert_id
com_between_devices[next_device_id][cur_device_id] = cur_expert_id
def redundant_expert_deployment(self, layer_workloads, original_deployment,
expert_from_device, node_num,
is_node_redundant, rendun_pos):
device_num, per_device_expert_num = original_deployment.shape
route_expert_num = layer_workloads.shape[0]
per_node_device_num = self.safe_exact_divide(device_num, node_num)
per_node_route_expert_num = per_node_device_num * (
per_device_expert_num - 1)
weights = np.zeros((route_expert_num, ), dtype='object')
for expert_id, workload_weight in enumerate(layer_workloads):
weights[expert_id] = (expert_id, workload_weight)
if is_node_redundant:
device_assignments = []
device_weights = []
device_loads = []
device_counts = []
com_between_devices = []
for node_id in range(node_num):
cur_node_weights = weights[node_id *
per_node_route_expert_num:(node_id +
1) *
per_node_route_expert_num]
cur_original_deployment = original_deployment[
node_id * per_node_device_num:(node_id + 1) *
per_node_device_num]
cur_node_rendun_pos = rendun_pos[node_id *
per_node_device_num:(node_id +
1) *
per_node_device_num]
cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again(
layer_workloads, cur_node_weights, cur_original_deployment,
expert_from_device, node_num, is_node_redundant,
cur_node_rendun_pos)
device_assignments += cur_device_assignments
device_weights += cur_device_weights
device_loads += cur_device_loads
device_counts += cur_device_counts
com_between_devices += cur_com_between_devices
else:
device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again(
layer_workloads, weights, original_deployment,
expert_from_device, node_num, is_node_redundant, rendun_pos)
report, max_load = self.generate_allocation_report(
device_assignments, device_weights, device_loads, device_counts)
return report, max_load, com_between_devices
@staticmethod
def two_device_exchange_experts(cur_device_result, exchange_device_result,
cur_exchanged_expert_id,
next_exchanged_expert_id, ave_workload,
increment, num_redundancy_expert):
cur_device_weight = cur_device_result['expert_weights']
next_device_weight = exchange_device_result['expert_weights']
cur_device_expert_id = cur_device_result['assigned_experts']
next_device_expert_id = exchange_device_result['assigned_experts']
cur_device_total_weight = cur_device_result['total_load']
next_device_total_weight = exchange_device_result['total_load']
max_weight = max(cur_device_total_weight, next_device_total_weight)
cur_exchange_index = -1
next_exchange_index = -1
for index, weight in enumerate(cur_device_weight):
for next_index, next_weight in enumerate(next_device_weight):
change_flag = True
if (cur_device_expert_id[index] in next_device_expert_id
or next_device_expert_id[next_index]
in cur_device_expert_id):
change_flag = False
if (cur_device_expert_id[index] not in cur_exchanged_expert_id
) and (next_device_expert_id[next_index]
not in next_exchanged_expert_id) and change_flag:
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight
next_total_weight_after_exchange = next_device_total_weight - next_weight + weight
exchange_max_weight = max(
cur_total_weight_after_exchange,
next_total_weight_after_exchange)
if exchange_max_weight < max_weight and (
max_weight -
exchange_max_weight) >= (ave_workload * increment):
max_weight = exchange_max_weight
cur_exchange_index = index
next_exchange_index = next_index
return cur_exchange_index, next_exchange_index
def expert_exchange_between_devices(self,
ave_workload,
increment,
cur_layer_result,
com_between_devices,
num_redundancy_expert,
node_idx=0,
per_node_device_num=0,
is_node_redundant=False):
if is_node_redundant:
cur_devices_result = cur_layer_result[node_idx *
per_node_device_num:
(node_idx + 1) *
per_node_device_num]
else:
cur_devices_result = cur_layer_result
devices_total_weight = []
for device in cur_devices_result:
devices_total_weight.append(
(device['total_load'], device['device_id'] - 1))
exchange_frequency = 100
while exchange_frequency > 0:
exchange_frequency -= 1
devices_total_weight.sort(key=lambda x: x[0])
max_weight_device_id = devices_total_weight[-1][1]
exchange = False
for index in range(0, len(devices_total_weight) - 1):
min_weight_device_id = devices_total_weight[index][1]
if min_weight_device_id not in com_between_devices[
max_weight_device_id]:
cur_exchanged_expert_id = list(
com_between_devices[max_weight_device_id].values())
next_exchanged_expert_id = list(
com_between_devices[min_weight_device_id].values())
cur_exchange_index, next_exchange_index = self.two_device_exchange_experts(
cur_layer_result[max_weight_device_id],
cur_layer_result[min_weight_device_id],
cur_exchanged_expert_id, next_exchanged_expert_id,
ave_workload, increment, num_redundancy_expert)
if cur_exchange_index != -1:
self.exchange_expert(cur_exchange_index,
next_exchange_index,
max_weight_device_id,
min_weight_device_id,
cur_layer_result,
com_between_devices)
devices_total_weight[-1] = (
cur_layer_result[max_weight_device_id]
['total_load'], max_weight_device_id)
devices_total_weight[index] = (
cur_layer_result[min_weight_device_id]
['total_load'], min_weight_device_id)
exchange = True
break
if not exchange:
break
def exchange_experts(self, layer_result, layer_com_between_devices,
num_nodes, device_num, is_node_redundant,
ave_workload, increment, num_redundancy_expert,
org_deployment):
global_deployment = []
if is_node_redundant:
per_node_device_num = self.safe_exact_divide(device_num, num_nodes)
for node_idx in range(num_nodes):
self.expert_exchange_between_devices(
ave_workload, increment, layer_result,
layer_com_between_devices, num_redundancy_expert, node_idx,
per_node_device_num, is_node_redundant)
else:
self.expert_exchange_between_devices(ave_workload, increment,
layer_result,
layer_com_between_devices,
num_redundancy_expert)
max_workload = 0
for box in layer_result:
global_deployment.append(box['assigned_experts'])
if max_workload < box['total_load']:
max_workload = box['total_load']
global_deployment = np.array(global_deployment)
return global_deployment, max_workload
def count_elements(self, lst):
count = 0
for item in lst:
if isinstance(item, list):
count += self.count_elements(item)
else:
count += 1
return count
@staticmethod
def constraint_expert_local_exchange(current_expert_table,
global_deployment):
for layer_id in range(len(global_deployment)):
for card_id in range(len(global_deployment[layer_id])):
current_list = [
int(x) for x in current_expert_table[layer_id][card_id]
]
new_list = [
int(x) for x in global_deployment[layer_id][card_id]
]
num = len(new_list)
new_index = [-1] * num
new_result = [-1] * num
remaining_elements = []
for i in range(num):
flag = True
for j in range(num):
if new_list[i] == current_list[j] and new_index[
j] == -1:
new_index[j] = 0
new_result[j] = current_list[j]
flag = False
break
if flag:
remaining_elements.append(new_list[i])
index = 0
for k in range(num):
if new_result[k] == -1:
new_result[k] = remaining_elements[index]
index += 1
global_deployment[layer_id][card_id] = new_result
return global_deployment
def rebalance_experts(self,
current_expert_table,
expert_workload,
is_node_redundant=False,
increment=0.01):
info = DynamicTable()
info.workload_table = expert_workload.numpy()
info.placement_table = current_expert_table.numpy()
assert info.workload_table is not None
layer_num, num_npus, experts_per_npu = info.workload_table.shape
expert_ids, counts = np.unique(info.placement_table[0],
return_counts=True)
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
num_original_expert = len(expert_ids)
layer_workloads = self.add_redundant(info.placement_table,
info.workload_table,
num_original_expert)
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
info.workload_table, layer_num)
npu_heat_all_origin = sum(max_heat_per_layer_before)
num_node = self.safe_exact_divide(num_npus, 8)
layer_num = layer_workloads.shape[0]
expert_num = layer_workloads.shape[1]
expert_from_device = np.zeros((layer_num, num_original_expert))
if num_original_expert != expert_num:
raise ValueError(
f"The number of original experts ({num_original_expert}) must match expert_num ({expert_num})"
)
if num_npus <= 0:
raise ValueError("The number of NPUs must be greater than 0")
if num_npus < num_redundancy_expert:
raise ValueError(
f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})"
)
global_deployment: list[list[list[int]]] = [[[]
for _ in range(num_npus)]
for _ in range(layer_num)]
layer_initial_imbalance = self.calculate_initial_imbalance(
info.placement_table, layer_workloads)
max_heat_per_layer_after = np.zeros([layer_num])
sum_num = 0
for layer in range(layer_num):
# print(f"Load imbalance ratio of layer {layer} under the new workload", layer_initial_imbalance[layer])
if layer_initial_imbalance[layer] < 1.01:
global_deployment[layer] = info.placement_table[layer]
continue
ave_workload = self.safe_divide(np.sum(layer_workloads[layer]),
num_npus)
rendun_pos: list[list[int]] = [[] for _ in range(num_npus)]
existing_experts = set()
for device_id, device in enumerate(info.placement_table[layer]):
for index, expert_id in enumerate(device):
if expert_id not in existing_experts:
existing_experts.add(expert_id)
expert_from_device[layer][expert_id] = device_id
else:
rendun_pos[device_id].append(index)
result, max_workload, com_between_devices = self.redundant_expert_deployment(
layer_workloads[layer], info.placement_table[layer],
expert_from_device[layer], num_node, is_node_redundant,
rendun_pos)
# print(layer, f"Imbalance Ratio after Redundancy Adjustment:", self.safe_divide(max_workload, ave_workload))
global_deployment[layer], new_max_workload = self.exchange_experts(
result, com_between_devices, num_node, num_npus,
is_node_redundant, ave_workload, increment,
num_redundancy_expert, info.placement_table[layer])
# print(layer, f"Imbalance Ratio after Swap Adjustment:", self.safe_divide(new_max_workload, ave_workload))
for device_id in range(num_npus):
com_between_devices[device_id] = {
key: value
for key, value in com_between_devices[device_id].items()
}
sum_num += self.count_elements(com_between_devices[device_id])
max_heat_per_layer_after[layer] = max(
result, key=lambda x: x['total_load'])['total_load']
layer_changed_ratio = []
for layer_idx in range(layer_num):
layer_changed_ratio.append(
self.safe_divide(max_heat_per_layer_after[layer_idx],
max_heat_per_layer_before[layer_idx]))
per_layer_priority = np.argsort(layer_changed_ratio)
npu_heat_all_after = sum(max_heat_per_layer_after)
change = 0
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
change = 1
new_global_deployment = self.constraint_expert_local_exchange(
current_expert_table, global_deployment)
return change, per_layer_priority, np.array(
new_global_deployment).tolist()
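    # Descriptive summary of the rebalance_experts flow above:
    #   1. add_redundant folds the per-device workload table into per-logical-expert loads.
    #   2. For every layer whose initial imbalance exceeds 1.01, redundant_expert_deployment
    #      re-assigns the spare (redundant) slots to the hottest experts.
    #   3. exchange_experts then greedily swaps experts between the most and least loaded
    #      devices while each swap improves the maximum load by at least ave_workload * increment.
    #   4. constraint_expert_local_exchange keeps experts that stay on a device in their
    #      original physical positions, minimizing expert-weight movement.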

View File

@@ -0,0 +1,33 @@
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory.
from .policy_abstract import DynamicConfig, EplbPolicy
from .policy_dynamic_ep import DynamicEplb
from .policy_dynamic_ep_v2 import DynamicEplbV2
from .policy_flashlb import FlashLB
from .policy_random import RandomLoadBalance
class PolicyFactory:
@staticmethod
def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
policy = {
            # Constraints when applying Dynamic EPLB policy V2:
            # if redundant experts exist, only one redundant expert can be placed
            # on each NPU and its physical expert index must be 0.
            # Expert weights are updated with a greedy device-to-device (d2d) composition.
0:
RandomLoadBalance, # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
1:
DynamicEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load
2:
DynamicEplbV2, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
3:
FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment
}
policy_class = policy.get(policy_type, RandomLoadBalance)
policy_instance = policy_class(config)
if policy_type == 3:
policy_instance.warm_up()
return policy_instance
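# Illustrative usage sketch (hypothetical, assuming DynamicConfig and the concrete
# policies are default-constructible from a bare config object); guarded so it never
# runs on import.
if __name__ == "__main__":
    example_config = DynamicConfig()
    balancer = PolicyFactory.generate_policy(policy_type=2, config=example_config)
    # policy_type=2 selects DynamicEplbV2; unknown types fall back to RandomLoadBalance.
    print(type(balancer).__name__)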

View File

@@ -0,0 +1,651 @@
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
import logging
from collections import deque
from typing import Dict
import numpy as np
import torch
from numba import njit # type: ignore
from .policy_abstract import DynamicConfig, EplbPolicy
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
@njit
def compute_piece_counts(X, P, stage_weights):
n_stage, N = X.shape
S = P - N
pieces = np.ones(N, dtype=np.int32)
unit = X / pieces # unit[i, j] = X[i, j] / pieces[j]
for _ in range(S):
deltas = np.zeros(N, dtype=np.float32)
for i in range(n_stage):
# Find top1 and top2
idx1 = -1
idx2 = -1
val1 = -1.0
val2 = -1.0
for j in range(N):
v = unit[i, j]
if v > val1:
val2 = val1
idx2 = idx1
val1 = v
idx1 = j
elif v > val2:
val2 = v
idx2 = j
origin = unit[i, idx1]
secv = unit[i, idx2]
alt = X[i, idx1] / (pieces[idx1] + 1)
delta = origin - (alt if alt > secv else secv)
deltas[idx1] += delta * stage_weights[i] if np.any(
delta) != 0 else stage_weights[i]
max_idx = np.argmax(deltas)
pieces[max_idx] += 1
for i in range(n_stage):
unit[i, max_idx] = X[i, max_idx] / pieces[max_idx]
# Compute max load
max_load = 0.0
for j in range(N):
total = 0.0
for i in range(n_stage):
total += unit[i, j]
if total > max_load:
max_load = total
return pieces
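# Worked example for compute_piece_counts above (illustrative): with a single stage,
# X = [[9.0, 3.0]] and P = 3, the one spare replica goes to the hotter expert, since
# splitting it drops its per-replica load from 9.0 to 4.5; the result is pieces = [2, 1].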
@njit
def jsq_placement(X, pieces, M, stage_weights):
n_stage, N = X.shape
total_piece = pieces.sum()
num_per_group = total_piece // M
# 1. Compute unit_hotness
unit_hotness = np.empty((n_stage, N), dtype=np.float32)
for i in range(N):
if pieces[i] > 0:
for s in range(n_stage):
unit_hotness[s, i] = X[s, i] / pieces[i]
else:
for s in range(n_stage):
unit_hotness[s, i] = 0.0
# 2. Sort by total hotness
scores = np.zeros(N, dtype=np.float32)
for i in range(N):
for s in range(n_stage):
scores[i] += unit_hotness[s, i]
idx = np.argsort(-scores)
# 3. Initialization
loads = np.zeros((n_stage, M), dtype=np.float32)
dev_phy_exp_n = np.zeros(M, dtype=np.int32)
deployment = -np.ones((M, num_per_group), dtype=np.int32)
dep_ptr = np.zeros(M, dtype=np.int32)
# 4. Main loop
for t in range(N):
i = idx[t]
used_device = list()
for _ in range(pieces[i]):
# 4.1 Construct w vector
w = np.empty(n_stage, dtype=np.float32)
for s in range(n_stage):
w[s] = unit_hotness[s, i]
# 4.2 Compute stage-level maximum load
stage_max = np.empty(n_stage, dtype=np.float32)
for s in range(n_stage):
max_val = loads[s, 0]
for k in range(1, M):
if loads[s, k] > max_val:
max_val = loads[s, k]
stage_max[s] = max_val
# 4.3 Compute denominator
denom = np.empty(n_stage, dtype=np.float32)
for s in range(n_stage):
sum_tmp = 0.0
for j in range(M):
sum_tmp += loads[s, j] + w[s]
denom[s] = sum_tmp / M + 1e-2
# 4.4 Find best device j
best_j = -1
best_val = 1e30
for j in range(M):
if dev_phy_exp_n[j] >= num_per_group:
continue
if j in used_device:
continue
score = 0.0
for s in range(n_stage):
tmp_sj = loads[s, j] + w[s]
numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s]
score += stage_weights[s] * (numer_sj / denom[s])
if score < best_val:
best_val = score
best_j = j
if best_j == -1:
continue
used_device.append(best_j)
# 4.5 Update status
for s in range(n_stage):
loads[s, best_j] += w[s]
ptr = dep_ptr[best_j]
deployment[best_j, ptr] = i
dep_ptr[best_j] += 1
dev_phy_exp_n[best_j] += 1
    # Handle remaining -1 slots: fill them with randomly chosen experts from range(N) not yet on this rank
for rank in range(M):
for col in range(num_per_group):
if deployment[rank, col] == -1:
                # Experts already placed on the current rank
current_rank_elements = set(deployment[rank, :])
                # Experts from range(N) not yet assigned to this rank
available = [
x for x in range(N) if x not in current_rank_elements
]
# Randomly select an available element to fill
if len(available) > 0:
rand_idx = np.random.randint(0, len(available))
deployment[rank, col] = available[rand_idx]
elif N > 0:
                    # All unique experts are already on this rank, so we can pick any expert randomly.
deployment[rank, col] = np.random.randint(0, N)
return deployment
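# Descriptive note: jsq_placement greedily assigns each replica of the hottest experts to
# the device that minimizes the stage-weighted, normalized would-be maximum load, while
# keeping replicas of the same expert on distinct devices and respecting the per-device
# slot budget; any slots still unfilled are padded with randomly chosen experts
# (preferring experts not already present on that rank).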
@njit
def slice_values(X, pieces):
total_len = 0
for i in range(X.shape[0]):
total_len += pieces[i]
result = np.empty(total_len, dtype=np.float32)
idx = 0
for i in range(X.shape[0]):
val = X[i] / pieces[i]
for _ in range(pieces[i]):
result[idx] = val
idx += 1
return result
@njit
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
simulated_deployment, stage_weights):
n_stage, N = X.shape
num_group = P // M
X_all = np.zeros(N, dtype=np.float32)
for i in range(n_stage):
for j in range(N):
X_all[j] += X[i, j]
sort_idx = np.argsort(np.negative(X_all))
X_sorted = X[:, sort_idx]
unit_load = np.empty(N, dtype=np.float32)
for j in range(N):
unit_load[j] = X_all[j] / simulated_pieces[j]
flat_deployment = simulated_deployment.reshape(-1)
simulated_load = np.zeros(M, dtype=np.float32)
for i in range(flat_deployment.shape[0]):
simulated_load[i // (flat_deployment.shape[0] //
M)] += unit_load[flat_deployment[i]]
slice_vals = slice_values(X_all, simulated_pieces)
sorted_slices = np.sort(slice_vals)[::-1]
simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M
cumulative_slices_used = np.zeros(N, dtype=np.int32)
acc = 0
for i in range(N):
acc += simulated_pieces[sort_idx[i]]
cumulative_slices_used[i] = acc
group_boundary_indices = np.zeros(num_group, dtype=np.int32)
for i in range(1, num_group + 1):
for j in range(N):
if cumulative_slices_used[j] >= i * M:
group_boundary_indices[i - 1] = j
break
slices_used_per_group = np.zeros(num_group, dtype=np.int32)
slices_used_per_group[0] = group_boundary_indices[0]
for i in range(1, num_group):
slices_used_per_group[
i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
slices_used_per_group = M - slices_used_per_group
loads = np.zeros(M, dtype=np.float32)
pieces = np.zeros(N, dtype=np.int32)
num_remain_slice = P - N
current_idx = 0
for g in range(num_group):
window = X_sorted[:, current_idx:current_idx + 2 * M]
low = max(0, current_idx + M - N)
high = min(num_remain_slice, M - 1)
while (high - low) > 1:
mid = int((high + low) // 2)
keep = M - mid
current_group = window[:, :keep]
current_pieces = compute_piece_counts(current_group, M,
stage_weights)
current_pieces = np.maximum(current_pieces, 1)
current_slice = slice_values(current_group.sum(0), current_pieces)
current_slice_sorted = np.sort(current_slice)
current_loads = loads + current_slice_sorted
current_max: np.float32 = np.max(current_loads)
current_min: np.float32 = np.min(current_loads)
current_slope = (current_max - current_min) / M
next_slope: np.float32 = np.max(simulated_slopes[current_idx +
keep:])
if abs(current_slope) > abs(next_slope):
low = mid
else:
high = mid
S = high
keep = M - S
current_group = window[:, :keep]
current_pieces = compute_piece_counts(current_group, M, stage_weights)
for i in range(keep):
pieces[sort_idx[current_idx + i]] = current_pieces[i]
current_slice = slice_values(current_group.sum(0), current_pieces)
current_slice_sorted = np.sort(current_slice)
loads += current_slice_sorted
loads = np.sort(loads)[::-1]
current_idx += keep
num_remain_slice -= S
return pieces
@njit
def compute_objective(deployment, X, pieces):
M, P = deployment.shape
loads = np.zeros(M)
for i in range(M):
for j in range(P):
expert = deployment[i, j]
if pieces[expert] == 0:
continue
loads[i] += X[expert] / pieces[expert]
mean_load = np.mean(loads)
max_load: np.float32 = np.max(loads)
obj = max_load / mean_load
return obj, loads
@njit
def auto_fix_new_placement(old_placement, new_placement):
"""
Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both
old_placement and new_placement remain in their original positions from old_placement.
New elements (unique to new_placement) will fill the remaining empty positions.
Args:
old_placement: Old deployment matrix with shape (num_ranks, num_experts)
new_placement: New deployment matrix to be fixed, must have the same shape as old_placement
Returns:
fixed_new: adjusted version of the new_placement matrix
"""
num_ranks, num_experts = old_placement.shape
fixed_new = np.empty_like(new_placement)
max_expert_old = old_placement.max() if num_experts > 0 else 0
max_expert_new = new_placement.max() if num_experts > 0 else 0
max_expert = max(max_expert_old, max_expert_new)
for rank_id in range(num_ranks):
old_row = old_placement[rank_id]
new_row = new_placement[rank_id]
index_array = np.full((max_expert + 1, num_experts),
-1,
dtype=np.int32)
count_array = np.zeros(max_expert + 1, dtype=np.int32)
for idx in range(num_experts):
val = old_row[idx]
if val >= 0 and val <= max_expert:
pos = count_array[val]
index_array[val, pos] = idx
count_array[val] += 1
old_counter = np.zeros(max_expert + 1, dtype=np.int32)
for idx in range(num_experts):
val = old_row[idx]
if val >= 0 and val <= max_expert:
old_counter[val] += 1
retain_elements = np.empty(num_experts, dtype=new_placement.dtype)
new_elements = np.empty(num_experts, dtype=new_placement.dtype)
retain_ptr = 0
new_ptr = 0
for val in new_row:
if val >= 0 and val <= max_expert and old_counter[val] > 0:
retain_elements[retain_ptr] = val
retain_ptr += 1
old_counter[val] -= 1
else:
new_elements[new_ptr] = val
new_ptr += 1
current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype)
for i in range(retain_ptr):
val = retain_elements[i]
if val >= 0 and val <= max_expert:
pos = count_array[val] - 1
if pos >= 0:
idx = index_array[val, pos]
current_fixed[idx] = val
count_array[val] -= 1
empty_indices = np.empty(num_experts, dtype=np.int32)
empty_ptr = 0
for idx in range(num_experts):
if current_fixed[idx] == -1:
empty_indices[empty_ptr] = idx
empty_ptr += 1
for i in range(new_ptr):
if i < empty_ptr:
current_fixed[empty_indices[i]] = new_elements[i]
fixed_new[rank_id] = current_fixed
return fixed_new
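# Worked example for auto_fix_new_placement (illustrative):
#   old_placement = [[1, 2, 3], [4, 5, 6]]
#   new_placement = [[3, 1, 7], [6, 8, 4]]
#   result        = [[1, 7, 3], [4, 8, 6]]
# Experts present in both tables (1 and 3 on rank 0, 4 and 6 on rank 1) keep their old
# physical slots; only the newly introduced experts (7 and 8) occupy the freed positions.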
class FlashLB(EplbPolicy):
def __init__(self, config: DynamicConfig):
super().__init__(config)
self.par_history: Dict[int, float] = {}
self.hotness_window: Dict[int, deque[float]] = {}
self.max_stage_window = (config.max_stage_window if hasattr(
config, "max_stage_window") else 1)
self.buffer_expert_layer_num = (
config.buffer_expert_layer_num if hasattr(
config, "buffer_expert_layer_num") else 58)
self.threshold_ratio = (config.threshold_ratio if hasattr(
config, "threshold_ratio") else 0)
def compute_expert_hotness(self, num_of_expert: int,
deployment: np.ndarray, rank_load: np.ndarray):
hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
deployment_flat = deployment.ravel()
rank_load_flat = rank_load.ravel()
np.add.at(hotness, deployment_flat, rank_load_flat)
return hotness
def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray):
n_stage, N = hotness.shape
if np.any(deployment < 0):
print(f"Invalid deployment with negative values: {deployment}")
raise ValueError("Deployment table contains negative values.")
counts = np.bincount(deployment.reshape(-1), minlength=N)
unit_hotness = np.divide(hotness,
counts,
out=np.zeros_like(hotness, dtype=float),
where=counts != 0)
stage_par = np.zeros(n_stage)
for i in range(n_stage):
stage_load = unit_hotness[i][deployment].sum(-1)
stage_par[i] = stage_load.max() / stage_load.mean()
return stage_par.mean()
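    # Example (illustrative): for a single stage whose per-device loads come out as
    # [4, 2, 2], this returns 4 / ((4 + 2 + 2) / 3) = 1.5, i.e. the hottest device
    # carries 1.5x the average load; values close to 1.0 indicate a balanced layer.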
def group_based_adaptive_bloating(self,
X,
P,
M,
stage_weights=None,
recorsive=False):
n_stage, N = X.shape
if stage_weights is None:
stage_weights = np.ones(n_stage, dtype=np.float32)
if recorsive:
(
simulated_deployment,
simulated_pieces,
) = self.group_based_adaptive_bloating(X,
P,
M,
stage_weights,
recorsive=False)
else:
simulated_pieces = compute_piece_counts(X, P, stage_weights)
simulated_deployment = jsq_placement(X, simulated_pieces, M,
stage_weights)
pieces = group_based_adaptive_bloating_kernel(
X.astype(np.float32),
P,
M,
simulated_pieces.astype(np.int32),
simulated_deployment.astype(np.int32),
stage_weights.astype(np.float32),
)
deployment = jsq_placement(X, pieces, M, stage_weights)
X_all = X.sum(0)
unit_load = np.divide(X_all,
pieces,
out=np.zeros_like(X_all, dtype=float),
where=pieces != 0)
load = unit_load[deployment].sum(-1)
sim_unit_load = X_all / simulated_pieces
sim_load = sim_unit_load[simulated_deployment].sum(-1)
if load.max() > sim_load.max():
return simulated_deployment, simulated_pieces
return deployment, pieces
def need_update(self, current_par, layer_id=0):
threshold = self.par_history.get(layer_id, 0.0)
return current_par >= self.threshold_ratio * threshold
def compute_stage_weight(self, hotness):
n_stage = hotness.shape[0]
stage_weights = np.zeros(n_stage)
for i in range(n_stage):
stage_weights[i] = hotness[i].sum()
stage_weights = stage_weights / stage_weights.max()
return stage_weights
def rebalance_layer(self, deployment, hotness, layer_id=0):
num_rank, expert_per_rank = deployment.shape
num_expert = np.unique(deployment.reshape(-1)).shape[0]
num_of_redundant_expert = num_rank * expert_per_rank - num_expert
current_par = self.compute_rank_load(deployment, hotness)
if not self.need_update(current_par, layer_id):
return deployment, current_par, current_par
stage_weights = self.compute_stage_weight(hotness)
new_deployment, _ = self.group_based_adaptive_bloating(
hotness,
num_expert + num_of_redundant_expert,
num_rank,
stage_weights,
recorsive=False,
)
if np.any(new_deployment < 0):
print(f"{new_deployment=}")
new_par = self.compute_rank_load(new_deployment, hotness)
return new_deployment, new_par, current_par
def register_hotness(self, deployment, rank_load, num_layer, num_expert):
for layer in range(num_layer):
if layer not in self.hotness_window:
self.hotness_window[layer] = deque(
maxlen=self.max_stage_window)
hotness = self.compute_expert_hotness(num_expert,
deployment[layer],
rank_load[layer])
self.hotness_window[layer].append(hotness)
def compress_by_avg_pooling_fast_nd(self, arr, m):
n, d = arr.shape
idx = (np.arange(n) * m // n)
result = np.zeros((m, d))
counts = np.zeros((m, 1))
np.add.at(result, idx, arr)
np.add.at(counts, idx, 1)
return result / counts
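    # Example (illustrative): with n = 4 history rows and m = 2, idx = [0, 0, 1, 1],
    # so rows 0-1 are averaged into output row 0 and rows 2-3 into output row 1.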
def rebalance_experts(self, current_expert_table, expert_workload):
current_deployment = np.array(current_expert_table)
expert_workload = np.array(expert_workload)
expert_workload += 1
num_layer = expert_workload.shape[0]
num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0]
self.register_hotness(current_deployment, expert_workload, num_layer,
num_expert)
new_deployment = current_deployment.copy()
layers_need_update = np.arange(num_layer)
new_par = np.zeros(layers_need_update.shape[0])
current_par = np.zeros(layers_need_update.shape[0])
for i, layer in enumerate(layers_need_update):
hotness = np.array(self.hotness_window[layer])
if hotness.shape[0] > self.max_stage_window:
hotness = self.compress_by_avg_pooling_fast_nd(
hotness, self.max_stage_window)
(
new_deployment[layer],
new_par[i],
current_par[i],
) = self.rebalance_layer(current_deployment[layer],
hotness,
layer_id=layer)
priority = new_par / current_par
priority_idx = np.argsort(priority)
priority_idx = priority_idx[priority[priority_idx] <
1][:self.buffer_expert_layer_num]
if np.all(expert_workload == 1):
for _, layer in enumerate(layers_need_update):
self.hotness_window[layer].pop()
return False, np.array([], dtype=int), current_deployment
change = len(priority_idx) > 0
if change:
for idx in priority_idx:
self.par_history[layers_need_update[idx]] = new_par[idx]
layers_need_update = priority_idx
deployment = current_deployment
for layer in layers_need_update:
deployment[layer] = auto_fix_new_placement(
current_deployment[layer], new_deployment[layer])
return change, layers_need_update, deployment
def generate_layered_experts(num_layers=58,
layer_shape=(32, 9),
expert_min=0,
expert_max=255):
"""
Generate expert deployment matrix meeting the following conditions:
- Total of num_layers layers
- Each layer has shape layer_shape (32,9)
- Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer
Args:
num_layers: Number of layers, default 58
layer_shape: Shape of a single layer, default (32,9)
expert_min: Minimum expert ID, default 0
expert_max: Maximum expert ID, default 255
Returns:
torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1])
"""
# 1. Basic parameter calculation
expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255)
layer_total = layer_shape[0] * layer_shape[
1] # Total elements in a single layer: 32*9=288
extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32
# 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts)
assert layer_total >= expert_num, (
f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, "
"cannot cover all experts")
# 3. Generate layers one by one
layers = []
for _ in range(num_layers):
# 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included)
full_experts = torch.arange(expert_min,
expert_max + 1,
dtype=torch.int64) # shape (256,)
# 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255)
extra_experts = torch.randint(expert_min,
expert_max + 1,
size=(extra_slots, ),
dtype=torch.int64) # shape (32,)
# 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer)
layer_flat = torch.cat([full_experts, extra_experts],
dim=0) # shape (288,)
# Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues)
shuffle_idx = torch.randperm(layer_flat.shape[0])
layer_shuffled = layer_flat[shuffle_idx]
# 3.4 Reshape to layer_shape (32,9)
layer = layer_shuffled.reshape(layer_shape)
layers.append(layer)
# 4. Stack all layers to get the final tensor
return torch.stack(layers, dim=0) # shape (58,32,9)
def warm_up():
exam_config = DynamicConfig()
exam_config.ep_worldsize = 32
exam_config.num_die_per_host = 16
algo = FlashLB(exam_config)
# Generate target tensor
expert_tensor = generate_layered_experts(num_layers=58,
layer_shape=(32, 9))
algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))

View File

@@ -0,0 +1,30 @@
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
import copy
import random
from .policy_abstract import DynamicConfig, EplbPolicy
random.seed(42)
class RandomLoadBalance(EplbPolicy):
def __init__(self, config: DynamicConfig):
super().__init__(config)
def rebalance_experts(self, current_expert_table, expert_workload):
new_table = copy.deepcopy(current_expert_table)
num_layers = len(current_expert_table)
for i in range(num_layers):
            # choose two cards (random sampling is disabled; fixed indices are used instead)
            # indices = random.sample(range(num_card), 2)
            indices = [3, 1]
# swap redundant experts
expert_id_to_exchange = new_table[i][indices[0]][-1].clone()
new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1]
new_table[i][indices[1]][-1] = expert_id_to_exchange
return 1, [-i for i in range(num_layers)], new_table
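# Illustrative sketch (hypothetical, assuming DynamicConfig is default-constructible and
# the expert table is a torch tensor of shape [num_layers, num_cards, experts_per_card]
# with at least 4 cards, since the policy swaps the last physical expert of cards 3 and 1).
if __name__ == "__main__":
    import torch
    table = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
    policy = RandomLoadBalance(DynamicConfig())
    changed, priority, new_table = policy.rebalance_experts(table, expert_workload=None)
    # For every layer, new_table[layer][3][-1] and new_table[layer][1][-1] are swapped.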

View File

@@ -0,0 +1,209 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this updator.
import numpy
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import logger
from vllm_npu.eplb.core.eplb_utils import EPLBParamUtils
from vllm_npu.eplb.core.eplb_worker import EplbProcess
class EplbUpdator:
def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
process):
self.ascend_config = ascend_config
self.init_eplb(self.ascend_config.expert_map_path, process)
self.eplb_loader = loader
self.eplb_process = eplb_process
self.shared_dict = self.eplb_process.shared_dict
def set_adaptor(self, adaptor):
self.adaptor = adaptor
self.num_moe_layers = self.adaptor.num_moe_layers
self.global_expert_num = self.adaptor.global_expert_num
def init_eplb(self, expert_map_path, process):
self.rank_id = dist.get_rank()
self.num_expert_load_gather = 10
self.periodic_load_gather = True
self.num_iterations_eplb_update: torch.int64 = self.ascend_config.num_iterations_eplb_update
EPLBParamUtils.check_iterations(self.num_iterations_eplb_update)
self.expert_map_path = expert_map_path
self.expert_map_record_path = self.ascend_config.expert_map_record_path
try:
if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
self.num_expert_load_gather = self.num_iterations_eplb_update
self.periodic_load_gather = False
except Exception:
self.num_expert_load_gather = self.num_iterations_eplb_update
self.periodic_load_gather = False
self.expert_map_initialized = False
self.gate_eplb = self.ascend_config.gate_eplb
self.reqs = []
self.update_info_all = []
self.cur_iterations: torch.int64 = 0
self.num_wait_worker_iterations: torch.int64 = self.ascend_config.num_wait_worker_iterations
EPLBParamUtils.check_iterations(self.num_wait_worker_iterations)
self.process = process
logger.info(
f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")
def update_iteration(self):
self.cur_iterations += 1
if self.cur_iterations == (self.num_iterations_eplb_update + \
self.num_wait_worker_iterations + self.num_moe_layers):
if self.expert_map_record_path is not None:
self.adaptor._export_tensor_to_file(
self.shared_dict["expert_maps"],
self.expert_map_record_path)
self.adaptor.model.clear_all_moe_loads()
if not self.gate_eplb:
self.cur_iterations = 0
def get_update_info_flag(self):
return self.cur_iterations == (self.num_iterations_eplb_update +
self.num_wait_worker_iterations - 1)
def wakeup_eplb_worker_flag(self):
return self.cur_iterations == (self.num_iterations_eplb_update - 1)
def update_expert_weight_flag(self):
weight_update_counter = self.cur_iterations - (
self.num_iterations_eplb_update + self.num_wait_worker_iterations)
return (weight_update_counter >= 0
and weight_update_counter < self.num_moe_layers)
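    # Iteration schedule implied by the flags above (descriptive summary):
    #   - at iteration num_iterations_eplb_update - 1 the MoE load is gathered and the
    #     EPLB worker is woken up;
    #   - num_wait_worker_iterations later the worker's update info is fetched;
    #   - each of the next num_moe_layers iterations applies one layer's expert map and
    #     weight update;
    #   - the counter is then reset, unless gate_eplb keeps the placement fixed afterwards.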
def get_init_expert_map(self):
try:
if not self.expert_map_initialized:
self.shared_dict[
"expert_maps"] = self.adaptor.get_init_expert_map_from_file(
self.num_moe_layers, self.expert_map_path)
self.expert_map_initialized = True
except Exception as e:
logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}",
exc_info=True)
def wakeup_eplb_worker(self):
self.eplb_process.planner_q.put(1)
def forward_before(self):
if self.update_expert_weight_flag():
(expert_send_info, expert_recv_info, updated_expert_map,
log2phy_map, layer_id) = self.update_info_all.pop(0)
log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
updated_expert_map_this_rank = torch.from_numpy(
numpy.array(updated_expert_map))
self.eplb_loader.generate_expert_d2d_transfer_task(
expert_send_info, expert_recv_info,
updated_expert_map_this_rank,
layer_id + self.adaptor.num_dense_layers)
# set asynchronous stream for d2d expert weight update
self.reqs = []
self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
def take_update_info_from_eplb_process(self):
        # After the EPLB worker process has been triggered, fetch the update info it produced
if self.get_update_info_flag():
self.update_info_all = self.eplb_process.block_update_q.get()
def forward_end(self):
if self.wakeup_eplb_worker_flag():
self.compute_and_set_moe_load(is_clear=True)
self.wakeup_eplb_worker()
if self.update_expert_weight_flag(
) and self.expert_map_record_path is None:
self.eplb_loader.update_expert_map_and_weight(self.reqs)
self.update_iteration()
def compute_and_set_moe_load(self, is_clear=False):
local_load = self.adaptor.get_rank_expert_workload()
self._gather_buffer = None
if dist.is_initialized():
self.world_size = dist.get_world_size()
self.device = local_load.device
if self._gather_buffer is None:
shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape,
dtype=local_load.dtype,
device=self.device)
dist.all_gather_into_tensor(self._gather_buffer, local_load)
moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
else:
moe_load = local_load.unsqueeze(1)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
return moe_load
def warm_up_eplb(self):
self.get_init_expert_map()
self.compute_and_set_moe_load()
src_tensor = torch.empty((1, ), device=self.device)
self_rank = dist.get_rank()
comm_op_list = []
for dst_rank in range(self.world_size):
if dst_rank == self_rank:
continue
comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
for src_rank in range(self.world_size):
if src_rank == self_rank:
continue
comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))
if comm_op_list:
reqs = dist.batch_isend_irecv(comm_op_list)
for req in reqs:
req.wait()
def shutdown(self):
"""
Clean up the EPLB process.
"""
if self.process.is_alive():
self.process.terminate()
self.process.join()
logger.info("[ModelRunner] EPLB process terminated")

77
vllm_npu/eplb/utils.py Normal file
View File

@@ -0,0 +1,77 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/pull/23553 is merged in vllm. Remove this model register.
import types
import torch
def get_expert_map(self, layer_id):
return self.model.layers[layer_id].mlp.experts.get_map()
def get_log2phy_map(self, layer_id):
return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
def get_all_expert_map(self, num_moe_layers):
all_loads = []
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
for layer_id in range(num_moe_layers):
load_tensor = self.get_expert_map(
layer_id + num_dense_layers) # (num_experts_per_layer,)
all_loads.append(load_tensor)
return torch.stack(all_loads, dim=0)
def get_all_moe_loads(self):
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
all_moe_loads = torch.stack(
[self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
for layer_id in range(self.num_moe_layers)],
dim=0
)
return all_moe_loads
def clear_all_moe_loads(self):
num_dense_layers = self.num_dense_layers if hasattr(
self, "num_dense_layers") else 0
for layer_id in range(self.num_moe_layers):
self.model.layers[layer_id +
num_dense_layers].mlp.experts.clear_moe_load()
def model_register(model, model_config):
model.get_expert_map = types.MethodType(get_expert_map, model)
model.get_log2phy_map = types.MethodType(get_log2phy_map, model)
model.get_all_expert_map = types.MethodType(get_all_expert_map, model)
model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)
config = model_config.hf_config
if config.model_type == "qwen3_moe":
model.num_moe_layers = config.num_hidden_layers
elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
model.num_dense_layers = config.first_k_dense_replace
model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
else:
raise NotImplementedError("EPLB is not supported.")
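# Minimal smoke-test sketch (hypothetical) using stand-in objects, so the registration
# path can be exercised without a real vLLM model; the attached methods are not called
# here because they require real decoder layers.
if __name__ == "__main__":
    from types import SimpleNamespace
    dummy_model = SimpleNamespace()
    dummy_model_config = SimpleNamespace(hf_config=SimpleNamespace(
        model_type="qwen3_moe", num_hidden_layers=4))
    model_register(dummy_model, dummy_model_config)
    print(dummy_model.num_moe_layers)  # qwen3_moe: every hidden layer is an MoE layer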