Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
Major overhaul
vllm_npu/ops/expert_load_balancer.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import json
import random
from typing import Dict, List

import torch
import torch.distributed as dist

class ExpertLoadBalancer(object):

    def __init__(self, expert_map_path, num_experts):
        self.expert_map_path = expert_map_path
        self.num_experts = num_experts
        self.tensor_data = []
        # Parse the expert-map JSON into a (layers, ranks, experts-per-rank) tensor.
        self.expert_map_tensor, self.layers_num, self.ranks_num = (
            self._expert_file_to_tensor())
        # Total physical expert slots = logical experts + redundant copies.
        self.global_expert_num = (
            num_experts + self.get_global_redundant_expert_num())
        self.expert_placement_map = self.generate_expert_placement_map()

    def _expert_file_to_tensor(self):
        with open(self.expert_map_path, "r") as f:
            data = json.load(f)
        layers_num = data["moe_layer_count"]
        gpus_num = data["layer_list"][0]["device_count"]

        # Collect the per-device expert-id lists for every MoE layer.
        for layer in data["layer_list"]:
            device_data = []
            for device in layer["device_list"]:
                device_data.append(device["device_expert"])
            self.tensor_data.append(device_data)
        expert_map_tensor = torch.tensor(self.tensor_data, dtype=torch.int32)
        return expert_map_tensor, layers_num, gpus_num
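
    # Illustrative only, not part of the original commit: the expert-map JSON
    # read above is assumed to contain one entry per MoE layer and one
    # expert-id list per device, for example:
    #
    #   {
    #     "moe_layer_count": 2,
    #     "layer_list": [
    #       {"device_count": 2,
    #        "device_list": [{"device_expert": [0, 1]},
    #                        {"device_expert": [2, 0]}]},
    #       {"device_count": 2,
    #        "device_list": [{"device_expert": [0, 2]},
    #                        {"device_expert": [1, 2]}]}
    #     ]
    #   }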

    def generate_index_dicts(self, tensor_2d):
        # For each rank's row of expert ids, map expert id -> global physical
        # index (its position in the flattened per-layer placement).
        dict_list = []
        current_idx = 0

        for row in tensor_2d:
            value_to_index = {}
            for i in range(row.size(0)):
                value = row[i].item()
                value_to_index[value] = current_idx + i
            dict_list.append(value_to_index)
            current_idx += row.size(0)

        return dict_list

    def generate_expert_placement_map(self):
        # expert_placement_map[layer, rank, expert_id] holds the local slot of
        # that expert on that rank, or -1 if the rank does not host it.
        expert_placement_map = torch.full(
            (self.layers_num, self.ranks_num, self.num_experts),
            -1,
            dtype=torch.int32,
        )
        for layer_id in range(self.layers_num):
            for gpu_id in range(self.ranks_num):
                e_ids = self.expert_map_tensor[layer_id, gpu_id]
                expert_placement_map[layer_id, gpu_id, e_ids] = torch.arange(
                    len(e_ids), dtype=torch.int32)
        return expert_placement_map

    def generate_log2phy_expert_map(self, layer_id):
        # Flattened physical layout of this layer: one slot per hosted copy.
        concatenated = torch.flatten(self.expert_map_tensor[layer_id])
        rank_expert_to_global = self.generate_index_dicts(
            self.expert_map_tensor[layer_id])
        # Group the physical slot indices by logical expert id.
        result_dict: Dict[int, List[int]] = {}
        for idx, value in enumerate(concatenated):
            key = value.item()
            if key not in result_dict:
                result_dict[key] = []
            result_dict[key].append(idx)

        # For each rank, prefer the local copy of an expert; otherwise route to
        # a randomly chosen physical copy.
        log2phy_map = torch.full((self.ranks_num, self.num_experts),
                                 -1,
                                 dtype=torch.int32)
        for rank in range(self.ranks_num):
            for key in result_dict:
                indices_in_concat = result_dict[key]
                if key in rank_expert_to_global[rank]:
                    log2phy_map[rank][key] = rank_expert_to_global[rank][key]
                else:
                    chosen_index = random.choice(indices_in_concat)
                    log2phy_map[rank][key] = chosen_index
        return log2phy_map
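
    # A small worked example (illustrative, not from the original commit):
    # with two ranks whose physical layout for a layer is [[0, 1], [2, 0]],
    # the flattened slots are [0, 1, 2, 0]. Rank 0 hosts experts 0 and 1
    # locally (slots 0 and 1) but has no copy of expert 2, so its logical
    # expert 2 is routed to slot 2 on rank 1. Rank 1 maps expert 0 to its own
    # slot 3 and expert 1 to the only available remote copy, slot 1.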

    def get_rank_placement_map(self, layer_id, rank_id):
        layer_expert_map = self.expert_placement_map[layer_id]
        # Move this rank's placement row onto the current NPU device.
        rank_expert_map = layer_expert_map[rank_id].to(
            torch.npu.current_device())
        rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
        return rank_local_expert_num, rank_expert_map

    def get_rank_log2phy_map(self, layer_id, rank_id):
        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
        return layer_log2phy_map[rank_id]

    def get_global_redundant_expert_num(self):
        # Physical slots across all ranks minus the number of logical experts.
        global_redundant_expert_num = (
            len(self.expert_map_tensor[0][0]) * self.ranks_num -
            self.num_experts)
        return global_redundant_expert_num

    def check_expert_map_tensor(self):
        # Verify that every rank loaded an identical expert map.
        if dist.is_initialized():
            try:
                rank = dist.get_rank()
                world_size = dist.get_world_size()
                all_expert_maps = [None for _ in range(world_size)]
                dist.all_gather_object(all_expert_maps, self.tensor_data)
                for rank_id, expert_map_tensor in enumerate(all_expert_maps):
                    if self.tensor_data != expert_map_tensor:
                        raise ValueError(
                            f"The expert map of rank {rank} is not equal to "
                            f"that of rank {rank_id}")
                return True
            except Exception as e:
                raise ValueError(
                    f"The expert maps of all ranks are inconsistent: {e}")
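
Below is a minimal, hypothetical usage sketch of the new class, not part of the commit itself: it writes a toy expert-map JSON for two MoE layers and two ranks, then queries the balancer on the CPU. The file name and the toy layout are invented for illustration, and get_rank_placement_map is skipped because it moves tensors to the current NPU device.

import json
import tempfile

from vllm_npu.ops.expert_load_balancer import ExpertLoadBalancer

# Toy map: 3 logical experts, 2 ranks, 2 physical slots per rank (1 redundant copy).
expert_map = {
    "moe_layer_count": 2,
    "layer_list": [
        {"device_count": 2,
         "device_list": [{"device_expert": [0, 1]}, {"device_expert": [2, 0]}]},
        {"device_count": 2,
         "device_list": [{"device_expert": [0, 2]}, {"device_expert": [1, 2]}]},
    ],
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(expert_map, f)
    map_path = f.name

balancer = ExpertLoadBalancer(map_path, num_experts=3)
print(balancer.get_global_redundant_expert_num())            # 1 redundant physical copy
print(balancer.get_rank_log2phy_map(layer_id=0, rank_id=0))  # e.g. tensor([0, 1, 2], dtype=torch.int32)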