Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git, synced 2026-02-20 19:50:15 +00:00
大改 (major overhaul)
vllm_npu/ops/moe/__init__.py  (new file, 0 lines)
vllm_npu/ops/moe/comm_utils.py  (new file, 113 lines)
@@ -0,0 +1,113 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.distributed
import torch.distributed as dist
import torch_npu

COMM_STREAM = None


def async_all_to_all(input_,
                     output_split_sizes,
                     input_split_sizes,
                     group,
                     event=None):
    if output_split_sizes is None:
        # Equal split (all2all)
        a2a_out = torch.empty_like(input_)
    else:
        # Unequal split (all2all-v)
        a2a_out = input_.new_empty(
            size=[sum(output_split_sizes)] + list(input_.size()[1:]),
            dtype=input_.dtype,
            device=torch.npu.current_device(),
        )

    if event:
        # multi stream wait event
        global COMM_STREAM
        if COMM_STREAM is None:
            COMM_STREAM = torch_npu.npu.Stream(
                device=torch.npu.current_device())
        with torch_npu.npu.stream(COMM_STREAM):
            event.wait()
            handle = dist.all_to_all_single(
                a2a_out,
                input_.contiguous(),
                output_split_sizes=output_split_sizes,
                input_split_sizes=input_split_sizes,
                group=group,
                async_op=True)
    else:
        handle = dist.all_to_all_single(a2a_out,
                                        input_.contiguous(),
                                        output_split_sizes=output_split_sizes,
                                        input_split_sizes=input_split_sizes,
                                        group=group,
                                        async_op=True)
    return input_, a2a_out, handle
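Illustrative usage sketch (not part of the commit): a hypothetical caller that exchanges routed tokens across an expert-parallel group with `async_all_to_all` and overlaps other work before waiting on the returned handle. The `ep_group` handle and the split-size lists are placeholders assumed for this sketch; only the call shape comes from the function above.

import torch

def exchange_tokens(routed_tokens: torch.Tensor, input_split_sizes: list,
                    output_split_sizes: list, ep_group) -> torch.Tensor:
    # Launch the non-blocking all-to-all-v; independent compute can overlap with it.
    _, received, handle = async_all_to_all(routed_tokens,
                                           output_split_sizes,
                                           input_split_sizes,
                                           ep_group)
    # ... overlap independent work here ...
    handle.wait()  # handle comes from dist.all_to_all_single(async_op=True)
    return received  # shape: [sum(output_split_sizes), hidden_size]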
def _gather_along_first_dim(input_, group, output_split_sizes=None):
    """Gather tensors and concatenate along the first dimension.

    Args:
        input_ (torch.Tensor):
            A tensor to be gathered.
        output_split_sizes (List[int], optional):
            A list specifying the sizes of the output splits along the first
            dimension. If None, equal splitting is assumed. Default: None.

    Returns:
        torch.Tensor: Gathered tensor.
    """
    world_size = torch.distributed.get_world_size(group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    if output_split_sizes is None:
        dim_size[0] = dim_size[0] * world_size

        output = torch.empty(dim_size,
                             dtype=input_.dtype,
                             device=torch.npu.current_device())
        torch.distributed.all_gather_into_tensor(output,
                                                 input_.contiguous(),
                                                 group=group)
    else:
        dim_size[0] = sum(output_split_sizes)
        output = torch.empty(dim_size,
                             dtype=input_.dtype,
                             device=torch.npu.current_device())
        output_tensor_list = list(
            torch.split(output, output_split_sizes, dim=0))
        torch.distributed.all_gather(output_tensor_list, input_, group=group)

    return output


def gather_from_sequence_parallel_region(
    input_,
    group,
    output_split_sizes=None,
):
    """Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
    return _gather_along_first_dim(input_, group, output_split_sizes)
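A small sketch of the forward behaviour of `gather_from_sequence_parallel_region`, assuming an already-initialized process group supplied by the caller; with equal splits each rank receives the world_size shards concatenated along the first dimension.

import torch
import torch.distributed as dist

def gather_example(local_shard: torch.Tensor, group) -> torch.Tensor:
    # Equal split: every rank holds the same number of rows, so the gathered
    # tensor simply stacks world_size shards along dim 0.
    gathered = gather_from_sequence_parallel_region(local_shard, group)
    assert gathered.shape[0] == local_shard.shape[0] * dist.get_world_size(group)
    return gathered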
vllm_npu/ops/moe/experts_selector.py  (new file, 277 lines)
@@ -0,0 +1,277 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional

import torch
import torch_npu
from vllm.forward_context import get_forward_context

from vllm_npu.ascend_config import get_ascend_config


def select_experts(hidden_states: torch.Tensor,
                   router_logits: torch.Tensor,
                   top_k: int,
                   use_grouped_topk: bool,
                   renormalize: bool,
                   topk_group: Optional[int] = None,
                   num_expert_group: Optional[int] = None,
                   custom_routing_function: Optional[Callable] = None,
                   scoring_func: str = "softmax",
                   routed_scaling_factor=1.0,
                   e_score_correction_bias: Optional[torch.Tensor] = None,
                   indices_type: Optional[torch.dtype] = None,
                   global_num_experts: int = -1):
    """
    Fused experts with select experts.

    Args:
        router_logits: Router logits of shape (num_tokens, num_experts).
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        top_k: Number of top-k experts.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.
        indices_type: dtype of the returned indices.
        global_num_experts: Global number of experts.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
    """
    # Prefetch w1_w3_proj.weight before gating (preprocess stage).
    weight_prefetch_method = get_forward_context().weight_prefetch_method
    if weight_prefetch_method:
        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
            hidden_states, "gate_up")
    topk_weights, topk_ids = _select_experts_with_fusion_ops(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        renormalize=renormalize,
        e_score_correction_bias=e_score_correction_bias,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        global_num_experts=global_num_experts)

    if topk_weights is None:
        topk_weights, topk_ids = _native_select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )
    return topk_weights, topk_ids
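Hedged caller sketch (not from the diff): how `select_experts` might be invoked with DeepSeek-V3-style routing, i.e. sigmoid scoring plus grouped top-k. The tensor shapes, the NPU device string, and the zero correction bias are illustrative assumptions; the keyword names come from the signature above.

import torch

# 16 tokens, 256 global experts, hidden size 1024 (made-up sizes).
hidden_states = torch.randn(16, 1024, dtype=torch.bfloat16, device="npu")
router_logits = torch.randn(16, 256, dtype=torch.float32, device="npu")

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=8,
    use_grouped_topk=True,
    renormalize=True,
    topk_group=4,
    num_expert_group=8,
    scoring_func="sigmoid",
    e_score_correction_bias=torch.zeros(256, device="npu"),
    global_num_experts=256)
# topk_weights: [16, 8] routing weights, topk_ids: [16, 8] int32 expert IDs.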
def _native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
):
    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group

    num_token = topk_weights.shape[0]
    grouped_weights = topk_weights.view(num_token, num_expert_group,
                                        -1).max(dim=-1).values
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
                                    k=topk_group,
                                    dim=-1,
                                    sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)

    return topk_weights


def _renormalize_topk_weights(
    topk_weights: torch.Tensor,
    renormalize: bool,
):
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights


def _select_expert_use_group_topk(
        topk_weights: torch.Tensor, topk_group: Optional[int],
        renormalize: bool, top_k: int, num_expert_group: Optional[int],
        e_score_correction_bias: Optional[torch.Tensor]):
    assert topk_group is not None
    assert num_expert_group is not None

    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights.
        original_weights = topk_weights
        topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

    # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
    # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
    topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
                                        topk_group)
    # TODO: bfloat16 is not supported by torch.topk with the GE graph.
    if e_score_correction_bias is not None:
        topk_ids = torch.topk(topk_weights.to(torch.float32),
                              k=top_k,
                              dim=-1,
                              sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_weights.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
                                            k=top_k,
                                            dim=-1,
                                            sorted=False)
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
    return topk_weights, topk_ids


def _select_experts_with_fusion_ops(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        use_grouped_topk: bool,
        renormalize: bool,
        e_score_correction_bias: Optional[torch.Tensor],
        topk_group: Optional[int],
        num_expert_group: Optional[int],
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor=1.0,
        global_num_experts: int = -1):

    topk_weights, topk_ids = None, None
    # NOTE: npu_moe_gating_top_k currently only supports the 'group_count=256' pattern.
    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
    if is_deepseek_v3_r1:
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=top_k,  # top_k is currently 8
            bias=e_score_correction_bias,
            k_group=topk_group,  # fixed: 4
            group_count=num_expert_group,  # fixed: 8
            group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum (fixed)
            renorm=0,  # 0: softmax->topk (fixed); 1: topk->softmax
            norm_type=1,  # 0: softmax; 1: sigmoid (fixed)
            # out_flag=False,  # todo new api; should the third output be output
            # y2_flag=False,  # old api; should the third output be output
            routed_scaling_factor=1,
            eps=float(1e-20))
    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
            x=router_logits, finished=None, k=top_k)
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

    return topk_weights, topk_ids


def _native_select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    global_num_experts: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).

    Raises:
        ValueError: If an unsupported scoring function is provided.
    """

    if scoring_func == "softmax":
        topk_weights = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        topk_weights = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if use_grouped_topk:
        return _select_expert_use_group_topk(
            topk_weights=topk_weights,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            e_score_correction_bias=e_score_correction_bias)

    if custom_routing_function is not None:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            global_num_experts=global_num_experts)
        # Required by npu_moe_init_routing
        topk_ids = topk_ids.to(torch.int32)
        return topk_weights, topk_ids

    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
    topk_weights = topk_weights.to(hidden_states.dtype)

    # Required by npu_moe_init_routing
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

    return topk_weights, topk_ids
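A CPU-runnable toy of the grouped top-k masking that `_native_grouped_topk` and `_select_expert_use_group_topk` implement above, with made-up scores for one token, eight experts arranged as four groups of two, and the best two groups kept (repeat_interleave is used here as an equivalent of the expand/reshape in the source).

import torch

scores = torch.tensor([[0.1, 0.9, 0.2, 0.3, 0.8, 0.7, 0.0, 0.4]])
group_max = scores.view(1, 4, 2).max(dim=-1).values        # [[0.9, 0.3, 0.8, 0.4]]
keep = torch.topk(group_max, k=2, dim=-1, sorted=False)[1]  # groups 0 and 2
mask = torch.zeros_like(group_max).scatter_(1, keep, 1)
masked = scores.masked_fill(~mask.bool().repeat_interleave(2, dim=1), 0.0)
# masked == [[0.1, 0.9, 0.0, 0.0, 0.8, 0.7, 0.0, 0.0]]: the subsequent top-k
# can only pick experts from the two surviving groups.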
vllm_npu/ops/moe/fused_moe_prepare_and_finalize.py  (new file, 520 lines)
@@ -0,0 +1,520 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from abc import ABC, abstractmethod

import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_dp_group, get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

from vllm_npu.utils import enable_sp


class FusedMoEPrepareAndFinalize(ABC):
    """
    Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
    in distributed environments. Subclasses implement specific communication strategies
    (e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing,
    broadcasting, and reduction across TP/DP/EP groups.

    Attributes:
        moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
            sizes, ranks, and communication settings.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        self.moe_config = moe_config

    @abstractmethod
    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare tensors before MoE computation. May involve:
        - Padding to align communication boundaries
        - Slicing across tensor-parallel ranks
        - Broadcasting across data-parallel ranks

        Args:
            hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
            router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
            enable_shared_expert_dp (bool): Skip DP communication for shared experts
            replace_allreduce (bool): Bypass default all-reduce behavior

        Returns:
            Tuple of:
            - processed hidden_states (may be padded/sliced/broadcasted)
            - processed router_logits (may be recomputed or broadcasted)
            - optional communication mask (e.g., mc2_mask for sparse ops)
        """
        raise NotImplementedError("Prepare not implemented.")

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalize MoE output. May involve:
        - Gathering sliced tensors across TP ranks
        - Reducing or scattering across DP ranks
        - Unpadding to original token count
        - Applying all-reduce across TP/EP if requested

        Args:
            hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced
            reduce_results (bool): Whether to apply all-reduce across TP/EP groups

        Returns:
            torch.Tensor: Final output with shape [original_num_tokens, hidden_size]
        """
        raise NotImplementedError("Finalize function not implemented.")
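A minimal sketch of the contract this base class defines, assuming a hypothetical single-rank setup in which nothing needs to be padded, sliced, or reduced; it only illustrates the prepare/finalize interface and is not one of the strategies shipped in this commit.

class _PassthroughPrepareAndFinalize(FusedMoEPrepareAndFinalize):
    """Hypothetical no-op strategy for a single rank."""

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                replace_allreduce: bool = False):
        # No TP/DP slicing and no mc2_mask in this toy setting.
        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        # Nothing was padded or sliced, so return the output unchanged.
        return hidden_states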
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using MC2 (Memory-Centric Communication).
    Designed for Ascend or environments requiring explicit padding and slicing control.
    Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        """
        Restore original TP configuration.
        vLLM flattens TP and DP into a single dimension; this method recovers
        the true TP world size and rank for correct tensor slicing.
        """
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
        1. Fetch `mc2_mask` and the target padding length from the forward context.
        2. Pad `hidden_states` and `router_logits` to the target length if needed.
        3. If TP > 1, split tensors along the token dimension and select the current TP rank's slice.
        4. Split and return the corresponding `mc2_mask`.

        Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.

        Returns:
            Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
        """
        self.replace_allreduce = replace_allreduce
        self.enable_shared_expert_dp = enable_shared_expert_dp
        forward_context = get_forward_context()
        mc2_mask = forward_context.mc2_mask
        if self.tp_size > 1:
            # Also slice mc2_mask
            split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
            mc2_mask = split_mc2_mask[self.tp_rank]

        if not self.replace_allreduce:
            self.num_tokens, _ = hidden_states.shape
            target_pad_length = forward_context.padded_num_tokens
            pad_size = target_pad_length - self.num_tokens

            # Pad if necessary (unless shared expert DP is enabled)
            if pad_size > 0 and not self.enable_shared_expert_dp:
                hidden_states = nn.functional.pad(hidden_states,
                                                  (0, 0, 0, pad_size))
                router_logits = nn.functional.pad(router_logits,
                                                  (0, 0, 0, pad_size))

            # Slice across TP ranks
            if self.tp_size > 1 and not self.enable_shared_expert_dp:
                split_hidden_states = torch.tensor_split(hidden_states,
                                                         self.tp_size,
                                                         dim=0)
                split_router_logits = torch.tensor_split(router_logits,
                                                         self.tp_size,
                                                         dim=0)
                hidden_states = split_hidden_states[self.tp_rank]
                router_logits = split_router_logits[self.tp_rank]
                self.split_hidden_states = split_hidden_states  # Save for finalize

        return hidden_states, router_logits, mc2_mask

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
        1. If TP > 1, all-gather slices from all TP ranks to reconstruct the full tensor.
        2. Unpad to the original token count if padding was applied.
        3. Return a tensor with shape [original_num_tokens, hidden_size].

        Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
        """
        if not (self.enable_shared_expert_dp or self.replace_allreduce):
            if self.tp_size > 1:
                # All-gather across TP group
                dist.all_gather(list(self.split_hidden_states), hidden_states,
                                self.moe_config.tp_group.device_group)
                hidden_states = torch.cat(self.split_hidden_states, dim=0)

                # TODO: This is a quick fix for the memory explosion issue in eager mode.
                # If the cache is not cleared after `self.split_hidden_states` is created,
                # it can lead to memory explosion in eager mode.
                del self.split_hidden_states

            # Unpad if necessary
            if self.num_tokens < hidden_states.shape[0]:
                hidden_states = hidden_states[:self.num_tokens]

        return hidden_states


class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using All-to-All style slicing.
    Similar to MC2 but does not use mc2_mask; instead pads to the TP size for uniform slicing.
    Used when num_tokens exceeds MC2's limitation (512 tokens/rank).
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        """Restore original TP configuration (same as MC2)."""
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
        1. Pad hidden_states and router_logits to the next multiple of TP size.
        2. If TP > 1, split along the token dim and select the current TP rank's slice.
        3. Save the splits for the later all-gather in finalize.

        Skipped if `enable_shared_expert_dp` or `replace_allreduce` is True.

        Returns:
            Tuple of (hidden_states, router_logits, None); no mask is used in All2All.
        """
        self.replace_allreduce = replace_allreduce
        self.enable_shared_expert_dp = enable_shared_expert_dp

        if not (self.replace_allreduce or self.enable_shared_expert_dp):
            self.num_tokens, _ = hidden_states.shape
            pad_size = self.tp_size - self.num_tokens  # Pad to TP size (cyclic)

            if pad_size > 0:
                hidden_states = nn.functional.pad(hidden_states,
                                                  (0, 0, 0, pad_size))
                router_logits = nn.functional.pad(router_logits,
                                                  (0, 0, 0, pad_size))

            if self.tp_size > 1:
                split_hidden_states = torch.tensor_split(hidden_states,
                                                         self.tp_size,
                                                         dim=0)
                split_router_logits = torch.tensor_split(router_logits,
                                                         self.tp_size,
                                                         dim=0)
                self.split_hidden_states = split_hidden_states

                hidden_states = split_hidden_states[self.tp_rank]
                router_logits = split_router_logits[self.tp_rank]

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
        1. If TP > 1, all-gather slices to reconstruct the full tensor.
        2. Unpad to the original token count.
        3. Return a [original_num_tokens, hidden_size] tensor.

        Skipped if `enable_shared_expert_dp` or `replace_allreduce` is True.
        """
        if not (self.enable_shared_expert_dp or self.replace_allreduce):
            if self.tp_size > 1:
                dist.all_gather(list(self.split_hidden_states), hidden_states,
                                self.moe_config.tp_group.device_group)
                hidden_states = torch.cat(self.split_hidden_states, dim=0)

                # TODO: This is a quick fix for the memory explosion issue in eager mode.
                # If the cache is not cleared after `self.split_hidden_states` is created,
                # it can lead to memory explosion in eager mode.
                del self.split_hidden_states

            if self.num_tokens < hidden_states.shape[0]:
                hidden_states = hidden_states[:self.num_tokens]

        return hidden_states
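A CPU-runnable toy of the pad/slice bookkeeping the strategies above rely on: pad the token dimension up to a multiple of tp_size, hand each rank one slice, and concatenating the slices and dropping the padding recovers the input. The modulo expression is one common way to compute the pad size and is an assumption of this sketch, not a line from the diff.

import torch
import torch.nn as nn

tp_size, num_tokens, hidden = 4, 6, 8
x = torch.randn(num_tokens, hidden)
pad_size = -num_tokens % tp_size                      # 2 rows of padding
padded = nn.functional.pad(x, (0, 0, 0, pad_size))    # [8, 8]
slices = torch.tensor_split(padded, tp_size, dim=0)   # 4 slices of 2 rows each
restored = torch.cat(slices, dim=0)[:num_tokens]      # unpad after the gather
assert torch.equal(restored, x)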
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using All-Gather + Reduce-Scatter on the EP group.
    There are two sets of prepare and finalize:
    1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
       we gather inputs across DP ranks before MoE and scatter outputs afterwards.
       The communication and calculation process is as follows (AG, AR and RS
       are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):

       Attn → TP AR → DP AG → MoE → DP RS → TP AR

    2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
       the above process becomes:

       TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS

       This strategy further combines TP AG + DP AG into an EP All-Gather and TP RS + DP RS
       into an EP Reduce-Scatter to improve communication performance. The optimized process is:

       TP AG → Attn → TP RS → EP AG → MoE → EP RS
    """

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
        All-gather hidden_states and router_logits to form global tensors.

        Returns:
            Tuple of (global_hidden_states, global_router_logits, None)
        """
        if enable_sp():
            return self._prepare_with_ep_group(hidden_states, router_logits)

        return self._prepare_with_dp_group(hidden_states, router_logits,
                                           enable_shared_expert_dp,
                                           replace_allreduce)

    def _prepare_with_ep_group(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            hidden_states, True, True)
        router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            router_logits, True, True)

        return hidden_states, router_logits, None

    def _prepare_with_dp_group(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
        1. Fetch the max token count across the DP group from the forward context.
        2. Pad local tensors to that size.
        3. All-gather across the DP group to form the global input tensor.

        Returns:
            Tuple of (global_hidden_states, global_router_logits, None)
        """
        self.enable_shared_expert_dp = enable_shared_expert_dp
        if self.moe_config.dp_size > 1:
            forward_context = get_forward_context()
            max_tokens_across_dp = forward_context.max_tokens_across_dp

            self.num_tokens = hidden_states.shape[0]
            pad_size = max_tokens_across_dp - self.num_tokens
            if pad_size > 0:
                hidden_states = nn.functional.pad(hidden_states,
                                                  (0, 0, 0, pad_size))
                router_logits = nn.functional.pad(router_logits,
                                                  (0, 0, 0, pad_size))

            # All-gather across DP group
            hidden_states = self.moe_config.dp_group.all_gather(
                hidden_states, 0)
            router_logits = self.moe_config.dp_group.all_gather(
                router_logits, 0)

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
        Reduce-scatter the hidden states.

        Returns:
            Tensor with shape [local_num_tokens, hidden_size]
        """
        if enable_sp():
            return self._finalize_with_ep_group(hidden_states)

        return self._finalize_with_dp_group(hidden_states, reduce_results)

    def _finalize_with_ep_group(self,
                                hidden_states: torch.Tensor) -> torch.Tensor:
        """
        The `reduce_results` argument is not needed in this function. Given that sequence parallelism is enabled:
        1. `reduce_results` is usually False when the model has shared experts and needs to
           all-reduce the hidden states after the results of shared experts and routed experts are added in FusedMoE.
           We reduce-scatter the hidden states here, then skip the all-reduce in FusedMoE and add it to the
           result of the shared experts.
        2. `reduce_results` is usually True when the model has no shared experts. We still reduce-scatter
           here, then skip the all-reduce in FusedMoE.
        """
        hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
            hidden_states, True)

        return hidden_states

    def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
                                reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
        1. If DP > 1 and not shared expert, reduce-scatter the output across the DP group.
        2. Slice to the original local token count.
        3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.

        Returns:
            Tensor with shape [original_local_num_tokens, hidden_size]
        """
        if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
            hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
            hidden_states = hidden_states[:self.num_tokens]

        if reduce_results and (self.moe_config.tp_size > 1
                               or self.moe_config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states


class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using Naive Multicast (point-to-point broadcast).
    Used in prefill when allgather is used in decode. Each DP rank broadcasts its slice to all others.
    Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries.
    """

    def _naive_multicast(self, x: torch.Tensor,
                         cu_tokens_across_dp_cpu: torch.Tensor):
        """
        Naive multicast implementation:
        1. Create a global buffer sized by the total tokens across DP.
        2. The current rank copies its slice into its designated buffer region.
        3. Each rank broadcasts its slice to all others via P2P.

        Args:
            x (torch.Tensor): Local tensor [local_tokens, hidden_size]
            cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank

        Returns:
            torch.Tensor: Global tensor [total_tokens, hidden_size]
        """
        assert len(x.shape) == 2, "Input must be 2D [tokens, features]"
        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
                             device=x.device,
                             dtype=x.dtype)

        # Copy local slice into buffer
        start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[
            self.moe_config.dp_rank - 1]
        end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
        buffer[start:end, :].copy_(x)

        # Broadcast each slice to all ranks
        for idx in range(self.moe_config.dp_size):
            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
            end = cu_tokens_across_dp_cpu[idx]
            get_dp_group().broadcast(buffer[start:end, :], idx)
        return buffer

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
        1. Fetch cumulative token boundaries from the forward context.
        2. Multicast hidden_states and router_logits to form global tensors.

        Returns:
            Tuple of (global_hidden_states, global_router_logits, None)
        """
        self.enable_shared_expert_dp = enable_shared_expert_dp

        if self.moe_config.dp_size > 1:
            self.cu_tokens_across_dp_cpu = get_forward_context(
            ).dp_metadata.cu_tokens_across_sp(1)
            hidden_states = self._naive_multicast(hidden_states,
                                                  self.cu_tokens_across_dp_cpu)
            router_logits = self._naive_multicast(router_logits,
                                                  self.cu_tokens_across_dp_cpu)

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
        1. If DP > 1 and not shared expert:
           - All-reduce across DP
           - Slice to the current rank's token range using cu_tokens_across_dp_cpu
        2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.

        Returns:
            Tensor with shape [local_num_tokens, hidden_size]
        """
        if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
            start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[
                self.moe_config.dp_rank - 1]
            end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
            hidden_states = get_dp_group().all_reduce(
                hidden_states)  # Sum across DP
            hidden_states = hidden_states[start:end, :]

        if reduce_results and (self.moe_config.tp_size > 1
                               or self.moe_config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states
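A CPU-runnable illustration of the `cu_tokens_across_dp_cpu` bookkeeping the naive multicast uses: cumulative token counts mark each DP rank's [start, end) region inside the global buffer. The per-rank counts are made up.

import torch

tokens_per_rank = torch.tensor([3, 5, 2])    # hypothetical DP ranks 0..2
cu_tokens = tokens_per_rank.cumsum(dim=0)    # tensor([ 3,  8, 10])

def slice_bounds(rank: int) -> tuple:
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    end = int(cu_tokens[rank])
    return start, end

assert slice_bounds(0) == (0, 3)
assert slice_bounds(1) == (3, 8)
assert slice_bounds(2) == (8, 10)            # buffer holds cu_tokens[-1] == 10 rows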
vllm_npu/ops/moe/moe_comm_method.py  (new file, 273 lines)
@@ -0,0 +1,273 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

from vllm_npu.ascend_forward_context import MoECommType
from vllm_npu.ops.moe.fused_moe_prepare_and_finalize import (
    FusedMoEPrepareAndFinalizeWithAll2All,
    FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
    FusedMoEPrepareAndFinalizeWithNaiveMulticast)
from vllm_npu.ops.moe.moe_mlp import unified_apply_mlp
from vllm_npu.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
                                               TokenDispatcherWithAllGather,
                                               TokenDispatcherWithMC2,
                                               TokenDispatcherWithMoge)

_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}


def get_moe_comm_method(
        moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
    return _MoECommMethods.get(moe_comm_type, None)


def setup_moe_comm_method(moe_config):
    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
    _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
        moe_config)
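A hedged sketch of the intended call pattern for this registry, assuming the vllm-npu plugin is installed and a valid `FusedMoEConfig` is available from layer construction; it is not a standalone script and only uses names introduced in this commit or by vLLM itself.

from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_npu.ascend_forward_context import MoECommType

def init_and_pick(moe_config: FusedMoEConfig, comm_type: MoECommType):
    # Build one impl per communication type once (e.g. at model init), then
    # look up the impl chosen for the current step (e.g. MC2 for decode).
    setup_moe_comm_method(moe_config)
    return get_moe_comm_method(comm_type)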
class MoECommMethod(ABC):
    """Base class for MoE communication methods."""

    def __init__(self, moe_config: FusedMoEConfig):
        self.model_type = get_current_vllm_config(
        ).model_config.hf_config.model_type
        self.moe_config = moe_config
        self.mc2_mask = None

        self.token_dispatcher = self._get_token_dispatcher()
        self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
        )

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
            hidden_states, router_logits, enable_shared_expert_dp,
            replace_allreduce)
        self.mc2_mask = mc2_mask
        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        hidden_states = self.fused_moe_prepare_finalize.finalize(
            hidden_states, reduce_results)
        return hidden_states

    def fused_experts(
            self,
            hidden_states: torch.Tensor,
            w1: torch.Tensor,
            w2: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            activation: str = "silu",
            apply_router_weight_on_input: bool = False,
            use_int8_w8a8: bool = False,
            use_int4_w4a8: bool = False,
            global_num_experts: Optional[int] = None,
            expert_map: Optional[torch.Tensor] = None,
            w1_scale: Optional[torch.Tensor] = None,
            w2_scale: Optional[torch.Tensor] = None,
            w1_scale_bias: torch.Tensor = None,
            w2_scale_bias: torch.Tensor = None,
            # For TorchAir graph
            is_torchair: bool = False,
            # For Cube/Vector parallel
            shared_experts: Optional[Any] = None,
            quantized_x_for_share: Optional[Any] = None,
            dynamic_scale_for_share: Optional[Any] = None,
            # For load balance
            log2phy: torch.Tensor = None,
            global_redundant_expert_num: int = 0,
            need_trans: bool = False,
            dynamic_eplb: bool = False):
        # Check constraints
        assert hidden_states.dtype in [
            torch.float32, torch.float16, torch.bfloat16
        ]

        moe_comm_method = get_forward_context().moe_comm_method
        assert moe_comm_method is not None, "Missing communication context"

        results = self.token_dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            mc2_mask=self.mc2_mask,
            apply_router_weight_on_input=apply_router_weight_on_input,
            with_quant=use_int8_w8a8 or use_int4_w4a8,
            dynamic_eplb=dynamic_eplb)

        permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
            results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")

        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
                                       w1=w1,
                                       w1_scale=w1_scale,
                                       w2=w2,
                                       w2_scale=w2_scale,
                                       group_list=expert_tokens,
                                       dynamic_scale=dynamic_scale,
                                       group_list_type=group_list_type,
                                       w1_scale_bias=w1_scale_bias,
                                       w2_scale_bias=w2_scale_bias,
                                       topk_scales=topk_scales,
                                       with_quant=use_int8_w8a8
                                       or use_int4_w4a8,
                                       fusion=use_int8_w8a8,
                                       need_trans=need_trans,
                                       dynamic_eplb=dynamic_eplb)

        final_hidden_states = self.token_dispatcher.token_combine(
            hidden_states=mlp_output)

        if dynamic_eplb:
            return (final_hidden_states, group_list_type, expert_tokens)

        return final_hidden_states

    @abstractmethod
    def _get_token_dispatcher(self):
        raise NotImplementedError(
            "_get_token_dispatcher function not implemented.")

    @abstractmethod
    def _get_fused_moe_prepare_finalize(self):
        raise NotImplementedError(
            "_get_fused_moe_prepare_finalize function not implemented.")
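A hedged end-to-end sketch of how the pieces in this commit are meant to fit together for one MoE layer: prepare, route, dispatch plus grouped MLP, and finalize. Quantization flags, expert maps, and shared-expert handling are omitted, and the plain softmax routing arguments are assumptions of this sketch, not the only supported configuration.

from vllm_npu.ops.moe.experts_selector import select_experts

def run_moe_layer(comm: MoECommMethod, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor, w1: torch.Tensor,
                  w2: torch.Tensor, top_k: int) -> torch.Tensor:
    # Pad/slice/broadcast the inputs for the chosen communication strategy.
    hidden_states, router_logits = comm.prepare(hidden_states, router_logits)
    # Pick routing weights and expert IDs per token.
    topk_weights, topk_ids = select_experts(hidden_states, router_logits,
                                            top_k=top_k,
                                            use_grouped_topk=False,
                                            renormalize=True)
    # Dispatch tokens, run the grouped expert MLP, and combine the outputs.
    out = comm.fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)
    # Undo padding/slicing and apply any requested reductions.
    return comm.finalize(out, reduce_results=True)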
class AllGatherCommImpl(MoECommMethod):
    """This implementation is the same as NativeAllGatherCommImpl,
    but uses NPU-specific ops for better performance.

    This implementation should be compatible with all scenarios, and
    thus it is the default implementation for MoE communication methods.
    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
    and `torch_npu.npu_moe_token_unpermute` for post-processing
    to handle the token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
    for pre-processing and post-processing, respectively.
    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
    use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed after the issue is fixed.
    """

    def _get_token_dispatcher(self):
        if self.model_type == "PanguProMoE":
            return TokenDispatcherWithMoge(
                top_k=self.moe_config.experts_per_token,
                num_experts=self.moe_config.num_experts,
                num_local_experts=self.moe_config.num_local_experts)
        else:
            return TokenDispatcherWithAllGather(
                top_k=self.moe_config.experts_per_token,
                num_experts=self.moe_config.num_experts,
                num_local_experts=self.moe_config.num_local_experts)

    def _get_fused_moe_prepare_finalize(self):
        return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)


class MC2CommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True`.
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
    3. `enable_expert_parallel=False` is not supported.

    This implementation uses the MC2 communication method, which is optimized for
    Communication and Computation parallelism on Ascend devices.
    """

    def _get_token_dispatcher(self):
        return TokenDispatcherWithMC2()

    def _get_fused_moe_prepare_finalize(self):
        return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)


class AlltoAllCommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True`.
    2. `npu_grouped_matmul` is available.

    This implementation uses all-to-all communication to exchange tokens
    between data parallel ranks before and after the MLP computation. It should
    have better performance than AllGatherCommImpl when DP size > 1.
    """

    def _get_token_dispatcher(self):
        return TokenDispatcherWithAll2AllV(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts)

    def _get_fused_moe_prepare_finalize(self):
        return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)


class NaiveMulticastCommImpl(MoECommMethod):
    """Naive multicast variant of the AllGather communication method.

    Like AllGatherCommImpl, it uses `torch_npu.npu_moe_init_routing_v2` for
    pre-processing and `torch_npu.npu_moe_token_unpermute` for post-processing
    to handle the token-to-expert mapping efficiently, but it prepares and
    finalizes tensors with naive point-to-point multicast across DP ranks.
    """

    def _get_token_dispatcher(self):
        return TokenDispatcherWithAllGather(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts)

    def _get_fused_moe_prepare_finalize(self):
        return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
vllm_npu/ops/moe/moe_mlp.py  (new file, 266 lines)
@@ -0,0 +1,266 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from typing import Optional

import torch
import torch_npu
from torch.nn.functional import pad
from vllm.forward_context import get_forward_context

from vllm_npu.ascend_forward_context import MoECommType
from vllm_npu.utils import dispose_tensor, is_310p


def cumsum_group_list(group_list: torch.Tensor,
                      src_list_type: int,
                      dst_list_type: int,
                      active_num: int = 0,
                      expert_num: int = 0) -> torch.Tensor:
    if src_list_type not in [0, 1, 2]:
        raise ValueError(
            f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
        )

    if src_list_type == dst_list_type:
        return group_list
    if src_list_type == 1 and dst_list_type == 0:
        return group_list.cumsum(dim=0)
    if src_list_type == 0 and dst_list_type == 1:
        group_diff = torch.diff(group_list)
        new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
        return new_group
    if src_list_type == 2 and dst_list_type == 0:
        experts = pad(group_list[:, 0], (1, 0))
        tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
        cumsum_group_list = torch.full(size=(expert_num, ),
                                       fill_value=active_num,
                                       dtype=group_list.dtype,
                                       device=group_list.device)

        for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
            if end > start:
                cumsum_group_list[start:end] = tokens[i]

        return cumsum_group_list
    raise NotImplementedError(
        f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
        "This feature is under development.")
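A CPU-runnable illustration of the group-list conversions handled above, mirroring the type 1 (per-expert counts) and type 0 (cumulative) cases with plain torch ops and made-up counts for four experts.

import torch

counts = torch.tensor([3, 0, 2, 5])            # type 1: tokens per expert
cumulative = counts.cumsum(dim=0)              # type 0: tensor([ 3,  3,  5, 10])
# and back again, exactly as in the 0 -> 1 branch above:
recovered = torch.cat([cumulative[:1], torch.diff(cumulative)], dim=0)
assert torch.equal(recovered, counts)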
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
fusion: bool = False,
|
||||
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
||||
hidden_states)
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
if fusion and not dynamic_eplb:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale)
|
||||
else:
|
||||
if w1_scale.dtype != torch.float32:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=cumsum_group_list(group_list, group_list_type, 1),
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
if fusion and not dynamic_eplb:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
bias=bias1,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale)
|
||||
else:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale.to(w2_scale.dtype)],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(hidden_states: torch.Tensor,
                      w1: torch.Tensor,
                      w2: torch.Tensor,
                      group_list: torch.Tensor,
                      group_list_type: int = 1,
                      topk_scales: Optional[torch.Tensor] = None,
                      need_trans: bool = True) -> torch.Tensor:

    if need_trans:
        w1 = w1.transpose(1, 2)
        w2 = w2.transpose(1, 2)

    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    if is_310p():
        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
            torch.float16)
    else:
        gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    if topk_scales is not None:
        gate_up_out *= topk_scales

    hidden_states = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    return hidden_states


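# Rough dataflow of unquant_apply_mlp above (shapes are illustrative assumptions,
# not checked by this module): hidden_states [T, H] are multiplied per expert group
# against w1 (transposed when need_trans is True), the gate/up halves are fused by
# npu_swiglu into [T, I], optionally scaled by topk_scales, and projected back to
# [T, H] with w2. group_list tells the grouped matmul how many consecutive rows
# belong to each expert.

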
def unified_apply_mlp(hidden_states: torch.Tensor,
                      w1: torch.Tensor,
                      w1_scale: torch.Tensor,
                      w2: torch.Tensor,
                      w2_scale: torch.Tensor,
                      group_list: torch.Tensor,
                      dynamic_scale: torch.Tensor = None,
                      group_list_type: int = 1,
                      w1_scale_bias: torch.Tensor = None,
                      w2_scale_bias: torch.Tensor = None,
                      topk_scales: Optional[torch.Tensor] = None,
                      with_quant: bool = False,
                      fusion: bool = False,
                      need_trans: bool = True,
                      dynamic_eplb: bool = False) -> torch.Tensor:
    if with_quant:
        return quant_apply_mlp(hidden_states=hidden_states,
                               w1=w1,
                               w1_scale=w1_scale,
                               w2=w2,
                               w2_scale=w2_scale,
                               group_list=group_list,
                               dynamic_scale=dynamic_scale,
                               group_list_type=group_list_type,
                               w1_scale_bias=w1_scale_bias,
                               w2_scale_bias=w2_scale_bias,
                               fusion=fusion,
                               dynamic_eplb=dynamic_eplb)
    else:
        return unquant_apply_mlp(hidden_states=hidden_states,
                                 w1=w1,
                                 w2=w2,
                                 group_list=group_list,
                                 group_list_type=group_list_type,
                                 topk_scales=topk_scales,
                                 need_trans=need_trans)
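
# Illustrative call pattern for unified_apply_mlp (a sketch under assumptions, not a
# reference usage; tensor names and shapes are made up for clarity):
#
#   out = unified_apply_mlp(hidden_states=sorted_tokens,  # [num_routed_tokens, H]
#                           w1=w1, w1_scale=w1_scale,      # per-expert gate_up weights
#                           w2=w2, w2_scale=w2_scale,      # per-expert down weights
#                           group_list=group_list,
#                           group_list_type=group_list_type,
#                           dynamic_scale=dynamic_scale,
#                           with_quant=True)
#
# With with_quant=False the *_scale arguments are ignored and the unquantized path
# in unquant_apply_mlp is taken instead.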
725
vllm_npu/ops/moe/token_dispatcher.py
Normal file
@@ -0,0 +1,725 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Optional

import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group

from vllm_npu.distributed.parallel_state import get_mc2_group
from vllm_npu.ops.moe.comm_utils import (
    async_all_to_all, gather_from_sequence_parallel_region)
from vllm_npu.utils import (AscendSocVersion, get_ascend_soc_version,
                            is_hierarchical_communication_enabled)


class MoETokenDispatcher(ABC):

    def __init__(self, **kwargs) -> None:
        """
        Initialize the MoE Token Dispatcher.
        """
        self.top_k = kwargs.get("top_k", 0)
        self.num_experts = kwargs.get("num_experts", 0)

    @property
    def ep_group(self):
        """Get expert model parallel group."""
        return get_ep_group().device_group

    @property
    def ep_rank(self):
        return get_ep_group().rank_in_group

    @property
    def ep_size(self):
        return get_ep_group().world_size

    @abstractmethod
    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        raise NotImplementedError("Dispatch function not implemented.")

    @abstractmethod
    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        raise NotImplementedError("Combine function not implemented.")


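# Typical lifecycle implied by the interface above (a sketch, not a guarantee made
# by this file): the caller runs token_dispatch() to permute/scatter tokens to the
# experts, feeds the returned hidden_states/group_list into the grouped expert MLP
# (e.g. unified_apply_mlp defined earlier in this change), and then calls
# token_combine() to gather the expert outputs back into the original token order.

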
class TokenDispatcherWithMC2(MoETokenDispatcher):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        device_group = get_mc2_group().device_group
        # TODO: Try local_rank = ep_group.rank_in_group
        local_rank = torch.distributed.get_rank(group=device_group)
        backend = device_group._get_backend(torch.device("npu"))
        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
        self.ep_rank_id = get_mc2_group().rank_in_group
        self.ep_world_size = get_mc2_group().world_size
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.need_extra_args = (
            get_ascend_soc_version() == AscendSocVersion.A3)

        # NOTE: Currently, on A3 we need to pass some extra params into dispatch & combine.
        self.a3_need_extra_args = \
            get_ascend_soc_version() == AscendSocVersion.A3
        # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
        # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
        # improve communication performance.
        self.need_expert_scale = is_hierarchical_communication_enabled()
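        # For example, on A2 the hierarchical path mentioned above is typically
        # enabled via (an assumption based on the NOTE above, not something this
        # class sets itself):
        #   export HCCL_INTRA_PCIE_ENABLE=1
        #   export HCCL_INTRA_ROCE_ENABLE=0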
        self.output = None
        self.assist_info_for_combine = None
        self.ep_recv_counts = None
        self.shared_act = None
        self.topk_ids = None
        self.topk_weights = None
        self.shared_experts = None
        self.mc2_mask = None
        self.with_quant = False
        self.expand_scales = None

    def get_dispatch_mc2_kwargs(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: torch.Tensor,
        global_redundant_expert_num: int = 0,
    ):
        quant_mode = 2 if self.with_quant else 0
        self.moe_expert_num = len(expert_map) + global_redundant_expert_num
        kwargs_mc2 = {
            "x": hidden_states,
            "expert_ids": topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_expert_num,
            "global_bs": 0,
            "expert_token_nums_type": 0,
        }

        stage1_kwargs = {
            "scales": None,
            "quant_mode": quant_mode,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }
        if self.need_extra_args:
            stage1_kwargs.update({
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.a3_need_extra_args and self.enable_dispatch_v2:
            stage1_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })
        if self.need_expert_scale:
            stage1_kwargs.update({
                "expert_scales": topk_weights.to(torch.float32),
            })

        kwargs_mc2.update(stage1_kwargs)
        return kwargs_mc2

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        # Apply log2phy if needed
        if log2phy is not None:
            topk_ids = log2phy[topk_ids]

        self.with_quant = with_quant
        self.expert_map = expert_map
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights
        self.shared_experts = shared_experts
        self.mc2_mask = mc2_mask

        kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
                                                  topk_ids, expert_map,
                                                  global_redundant_expert_num)
        self.output = torch_npu.npu_moe_distribute_dispatch_v2(
            **kwargs_mc2
        ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
            **kwargs_mc2)
        # comm_stream.wait_stream(torch.npu.current_stream())
        expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \
            self.ep_recv_counts, _, self.expand_scales = self.output[0:7]

        if self.with_quant:
            if shared_experts is not None:
                share_up_out, _ = shared_experts.gate_up_proj(
                    (quantized_x_for_share, dynamic_scale_for_share))
                shared_gate_up, shared_dequant_scale = share_up_out[
                    0], share_up_out[1]

                shared_act_out = shared_experts.act_fn(
                    (shared_gate_up, shared_dequant_scale))
                self.shared_act, self.swiglu_out_scale = \
                    shared_act_out[0], shared_act_out[1]

        else:
            if shared_experts is not None:
                shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
                self.shared_act = shared_experts.act_fn(shared_gate_up)
        group_list_type = 0
        return {
            "group_list_type": group_list_type,
            "hidden_states": expand_x,
            "group_list": expert_token_nums,
            "dynamic_scale": dynamic_scale,
        }

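    # The dict returned by token_dispatch above mirrors the keyword arguments of
    # unified_apply_mlp (hidden_states, group_list, group_list_type, dynamic_scale),
    # so callers can forward it almost unchanged to the expert MLP. This is an
    # observation about the surrounding code, not an enforced contract.
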
    def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
        assert self.expert_map is not None
        assert self.topk_weights is not None
        assert self.topk_ids is not None
        assert self.output is not None
        # moeCombine
        kwargs_mc2 = {
            "expand_x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights.to(torch.float32),
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_expert_num,
            "global_bs": 0,
        }
        if self.with_quant:
            tp_recv_counts = torch.empty(1,
                                         dtype=torch.int32,
                                         device=hidden_states.device)
        else:
            tp_recv_counts = self.output[5]
        stage3_kwargs = {
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
            "expand_scales": self.expand_scales,
        }
        if self.enable_dispatch_v2:
            stage3_kwargs.update({
                "assist_info_for_combine": self.assist_info_for_combine,
            })
        else:
            stage3_kwargs.update({
                "expand_idx": self.assist_info_for_combine,
            })
        if self.need_extra_args:
            stage3_kwargs.update({
                "tp_send_counts": tp_recv_counts,
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.a3_need_extra_args and self.enable_dispatch_v2:
            stage3_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })
        kwargs_mc2.update(stage3_kwargs)
        return kwargs_mc2

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
        hidden_states = torch_npu.npu_moe_distribute_combine_v2(
            **kwargs_mc2
        ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
            **kwargs_mc2)

        # these values are no longer used, so they need to be set to None for memory release.
        self.output = None
        self.assist_info_for_combine = None
        self.ep_recv_counts = None
        self.topk_ids = None
        self.topk_weights = None
        self.mc2_mask = None
        self.expert_map = None
        self.expand_scales = None

        if self.shared_experts is None:
            return hidden_states
        else:
            if self.with_quant:
                shared_hidden_states, _ = self.shared_experts.down_proj(
                    (self.shared_act, self.swiglu_out_scale))
            else:
                shared_hidden_states, _ = self.shared_experts.down_proj(
                    self.shared_act)
            self.shared_act = None
            self.shared_experts = None
            self.swiglu_out_scale = None
            return hidden_states, shared_hidden_states


class TokenDispatcherWithAllGather(MoETokenDispatcher):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = False
        self.max_num_tokens = kwargs.get("max_num_tokens")
        self.num_experts_local = kwargs.get("num_local_experts", 0)
        self.sorted_weights = None
        self.expanded_row_idx = None
        self.sorted_token_indices = None
        self.original_shape = None
        self.mask = None
        self.expert_map = None
        self.topk_weights = None
        self.topk_ids = None
        self.with_quant = False

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        self.with_quant = with_quant
        self.original_shape = hidden_states.shape

        num_tokens = hidden_states.shape[:-1].numel()
        self.expert_map = expert_map
        self.topk_weights = topk_weights
        self.topk_ids = topk_ids
        self.apply_router_weight_on_input = apply_router_weight_on_input
        if self.apply_router_weight_on_input:
            assert (topk_weights.dim() == 2
                    ), "`topk_weights` should be in shape (num_tokens, topk)"
            _, topk = topk_weights.shape
            assert (
                topk == 1
            ), "Only support topk=1 when `apply_router_weight_on_input` is True"
            hidden_states = hidden_states * \
                topk_weights.to(hidden_states.dtype)
        if expert_map is not None:
            global_num_experts = len(expert_map) + global_redundant_expert_num
            mask = (expert_map[topk_ids] != -1)
            self.topk_weights = topk_weights * mask
            first_expert_idx = get_ep_group(
            ).rank_in_group * self.num_experts_local
            last_expert_idx = first_expert_idx + self.num_experts_local
        else:
            first_expert_idx = 0
            last_expert_idx = self.num_experts_local
            global_num_experts = self.num_experts_local

        sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.top_k,
                expert_num=global_num_experts,
                expert_tokens_num_type=1,
                expert_tokens_num_flag=True,
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=1 if self.with_quant else -1,
            ))
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 1  # `count` mode
        return {
            "group_list_type": group_list_type,
            "hidden_states": sorted_hidden_states,
            "group_list": expert_tokens,
            "dynamic_scale": pertoken_scale if self.with_quant else None,
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        assert self.original_shape is not None
        final_hidden_states = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=hidden_states,
            sorted_indices=torch.abs(self.expanded_row_idx),
            probs=self.topk_weights)
        if len(self.original_shape) == 3:
            final_hidden_states = final_hidden_states.view(self.original_shape)

        # these values are no longer used, so they need to be set to None for memory release.
        self.expert_map = None
        self.topk_weights = None
        self.topk_ids = None
        self.expanded_row_idx = None
        return final_hidden_states


# mypy: disable-error-code="override"
class TokenDispatcherWithMoge(MoETokenDispatcher):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = False
        self.local_num_experts = self.num_experts // self.ep_size
        self.local_num_group = self.top_k // self.ep_size
        self.bsz = None

    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        self.bsz, _ = hidden_states.shape
        flatten_topk_ids = topk_ids.view(-1)
        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
        sorted_hidden_states = hidden_states.index_select(
            0, self.sorted_topk_ids // self.local_num_group)

        experts_id = torch.arange(0,
                                  self.local_num_experts,
                                  dtype=topk_ids.dtype,
                                  device=topk_ids.device)
        num_tokens_per_expert = (
            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
                torch.float32).sum(0)
        topk_scales = topk_weights.view(-1).index_select(
            0, self.sorted_topk_ids).unsqueeze(-1)
        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
        group_list_type = 0
        return {
            "group_list_type": group_list_type,
            "hidden_states": sorted_hidden_states,
            "group_list": group_list,
            "topk_scales": topk_scales,
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
            torch.int32)
        unsorted_hidden_states = hidden_states.index_select(
            0, unsorted_topk_ids)
        final_hidden_states = unsorted_hidden_states.reshape(
            self.bsz, self.top_k // self.ep_size, -1).sum(1)
        return final_hidden_states


class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
    """
    The implementation of the AlltoAll-based token dispatcher, which handles token
    dispatching on the sequence level instead of token level. The core of this
    implementation lies in each device dispatching on the entire sequence, with the
    hidden state being partitioned.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.with_quant = False
        self.num_local_experts = kwargs.get("num_local_experts", 0)

        self.hidden_shape = None
        self.topk_weights = None
        self.input_splits = None
        self.output_splits = None
        self.hidden_shape_before_permute = None

        # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
        # to each local expert by all ranks.
        self.num_global_tokens_per_local_expert = None

        # cached intermediate tensors.
        self.tokens_per_expert = None
        self.global_input_tokens_local_experts_indices = None

        assert self.num_local_experts > 0, "Expected at least one expert"
        if self.num_local_experts > 1:
            self.expert_ids_per_ep_rank = torch.tensor(
                [i % self.num_local_experts for i in range(self.num_experts)],
                dtype=torch.int32,
                device=torch.npu.current_device(),
            )

        local_expert_indices_offset = (self.ep_rank * self.num_local_experts)

        self.local_expert_indices = [
            local_expert_indices_offset + i
            for i in range(self.num_local_experts)
        ]
        assert (len(self.local_expert_indices) == self.num_local_experts
                ), "Invalid local expert indices"
        for i in range(len(self.local_expert_indices) - 1):
            assert (self.local_expert_indices[i] ==
                    self.local_expert_indices[i + 1] -
                    1), "local_expert_indices must be continuous"

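    # High-level flow (a reading of the methods below, with the quantized branch
    # omitted): _dispatch_preprocess permutes local tokens by expert id and computes
    # input_splits/output_splits, async_all_to_all exchanges the permuted tokens
    # across EP ranks, _dispatch_postprocess re-sorts the received tokens per local
    # expert, and token_combine runs the same steps in reverse to restore the
    # original token order.
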
    def token_dispatch(self,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       expert_map: Optional[torch.Tensor] = None,
                       log2phy: Optional[torch.Tensor] = None,
                       global_redundant_expert_num: int = 0,
                       shared_experts: Optional[Any] = None,
                       quantized_x_for_share: Optional[Any] = None,
                       dynamic_scale_for_share: Optional[Any] = None,
                       mc2_mask: Optional[torch.Tensor] = None,
                       apply_router_weight_on_input: bool = False,
                       with_quant: bool = False,
                       dynamic_eplb: bool = False):
        self.with_quant = with_quant
        self.hidden_shape = hidden_states.shape
        self.topk_weights = topk_weights
        assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
        assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"

        if log2phy is not None:
            topk_ids = log2phy[topk_ids]

        permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
            hidden_states, topk_ids)
        self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping

        dynamic_scale_after_all2all = None
        if self.with_quant:
            permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
                permutated_local_input_tokens)

            _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
                dynamic_scale,
                self.output_splits,
                self.input_splits,
                self.ep_group,
            )
            permute2_ep_all_to_all_handle.wait()
            dynamic_scale.untyped_storage().resize_(0)

        _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
            permutated_local_input_tokens,
            self.output_splits,
            self.input_splits,
            self.ep_group,
        )
        permute1_ep_all_to_all_handle.wait()
        permutated_local_input_tokens.untyped_storage().resize_(0)

        global_input_tokens, dynamic_scale = self._dispatch_postprocess(
            global_input_tokens, dynamic_scale_after_all2all)
        return {
            "hidden_states": global_input_tokens,
            "group_list": tokens_per_expert,
            "dynamic_scale": dynamic_scale,
            "group_list_type": 1
        }

    def token_combine(self,
                      hidden_states: torch.Tensor,
                      bias: torch.Tensor = None):
        assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

        hidden_states = self._combine_preprocess(hidden_states)

        # Perform expert parallel AlltoAll communication
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        _, permutated_local_input_tokens, handle = async_all_to_all(
            hidden_states, self.input_splits, self.output_splits,
            self.ep_group)
        handle.wait()
        hidden_states.untyped_storage().resize_(0)

        output = self._combine_postprocess(permutated_local_input_tokens)

        # these values are no longer used, so they need to be set to None for memory release.
        self.input_splits = None
        self.output_splits = None
        self.num_global_tokens_per_local_expert = None
        self.topk_weights = None
        self.reversed_local_input_permutation_mapping = None
        self.reversed_global_input_permutation_mapping = None
        self.global_input_tokens_local_experts_indices = None

        return output

    def _dispatch_preprocess(self, hidden_states, topk_ids):
        assert self.hidden_shape is not None
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        tokens_per_expert = self._preprocess(topk_ids)

        self.hidden_shape_before_permute = hidden_states.shape

        permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
            tokens=hidden_states,
            indices=topk_ids,
            num_out_tokens=self.num_out_tokens,
        )
        return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert

    def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
        num_local_tokens_per_expert = torch.histc(topk_ids,
                                                  bins=self.num_experts,
                                                  min=0,
                                                  max=self.num_experts)

        ep_size = self.ep_size

        # Dropless
        self.num_out_tokens = topk_ids.numel()

        # ===================================================
        # Calculate input_splits, output_splits for alltoall-v.
        # ===================================================
        self.input_splits = (num_local_tokens_per_expert.reshape(
            ep_size,
            self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
                                                   non_blocking=True).numpy())
        num_global_tokens_per_expert = gather_from_sequence_parallel_region(
            num_local_tokens_per_expert,
            group=self.ep_group).reshape(ep_size, self.num_experts)
        self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
            0]:self.local_expert_indices[-1] + 1]
        if self.num_global_tokens_per_local_expert is None:
            raise ValueError(
                "num_global_tokens_per_local_expert must be set before sum.")
        self.output_splits = (self.num_global_tokens_per_local_expert.sum(
            axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
        num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
            axis=0)
        # ===================================================
        # num_global_tokens_per_expert: [ep_size, num_experts]
        # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
        # num_tokens_per_local_expert: [num_local_experts]
        # ===================================================

        if self.num_local_experts > 1:
            if self.num_global_tokens_per_local_expert is None:
                raise ValueError(
                    "num_global_tokens_per_local_expert must be set before operations."
                )
            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                self.num_global_tokens_per_local_expert.ravel())
        else:
            # TODO: This full synchronization can be a performance bottleneck.
            # A more granular sync (e.g., blocking D2H copies) should be investigated.
            torch.npu.synchronize()

        return num_tokens_per_local_expert

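    # Worked example for _preprocess (hypothetical numbers): with ep_size=2 and
    # num_local_experts=2, num_local_tokens_per_expert=[3, 1, 2, 4] reshapes to
    # [[3, 1], [2, 4]], so input_splits=[4, 6] tokens are sent to rank 0 and rank 1
    # respectively; output_splits is built the same way from the gathered per-rank
    # counts for this rank's local experts.
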
    def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
        # Early return if no local experts or no tokens
        if self.num_local_experts <= 1:
            return global_input_tokens, None

        # Handle quantized case
        if self.with_quant:
            assert self.global_input_tokens_local_experts_indices is not None, \
                "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
            expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
                -1)
            active_num = self.global_input_tokens_local_experts_indices.numel()

            # Handle case with no active tokens
            if active_num <= 0:
                self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
                return global_input_tokens, dynamic_scale

            # Process with active tokens
            global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
                global_input_tokens,
                expert_idx_2d,
                scale=dynamic_scale,
                active_num=active_num,
                expert_capacity=0,
                expert_num=self.num_local_experts,
                expert_tokens_num_type=1,
                expert_tokens_num_flag=True,
                active_expert_range=[0, self.num_local_experts],
                quant_mode=-1,
                row_idx_type=0)
            return global_input_tokens, expanded_scale

        # Handle non-quantized case
        global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
            global_input_tokens,
            self.global_input_tokens_local_experts_indices)
        return global_input_tokens, None

    def _combine_preprocess(self, hidden_states):
        # Unpermutation 2: expert output to AlltoAll input
        if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
            hidden_states = torch_npu.npu_moe_token_unpermute(
                hidden_states, self.reversed_global_input_permutation_mapping)

        return hidden_states

    def _combine_postprocess(self, permutated_local_input_tokens):
        # Unpermutation 1: AlltoAll output to output
        output = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=permutated_local_input_tokens,
            sorted_indices=self.reversed_local_input_permutation_mapping.to(
                torch.int32),
            probs=self.topk_weights,
            restore_shape=self.hidden_shape_before_permute)

        # Reshape the output tensor
        output = output.view(self.hidden_shape)
        return output