Files
vllm-npu-plugin/vllm_npu/ops/moe/fused_moe_prepare_and_finalize.py
2026-02-10 23:08:39 +08:00

521 lines
22 KiB
Python

# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_npu.utils import enable_sp
class FusedMoEPrepareAndFinalize(ABC):
"""
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
in distributed environments. Subclasses implement specific communication strategies
(e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing,
broadcasting, and reduction across TP/DP/EP groups.
Attributes:
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
sizes, ranks, and communication settings.
"""
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
@abstractmethod
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
- Slicing across tensor-parallel ranks
- Broadcasting across data-parallel ranks
Args:
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
enable_shared_expert_dp (bool): Skip DP communication for shared experts
replace_allreduce (bool): Bypass default all-reduce behavior
Returns:
Tuple of:
- processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops)
"""
raise NotImplementedError("Prepare not implemented.")
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks
- Reducing or scattering across DP ranks
- Unpadding to original token count
- Applying all-reduce across TP/EP if requested
Args:
hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced
reduce_results (bool): Whether to apply all-reduce across TP/EP groups
Returns:
torch.Tensor: Final output with shape [original_num_tokens, hidden_size]
"""
raise NotImplementedError("Finalize function not implemented.")
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using MC2 (Memory-Centric Communication).
Designed for Ascend or environments requiring explicit padding and slicing control.
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""
Restore original TP configuration.
vLLM flattens TP and DP into a single dimension; this method recovers
the true TP world size and rank for correct tensor slicing.
"""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
2. Pad `hidden_states` and `router_logits` to target length if needed.
3. If TP > 1, split tensors along token dimension and select current TP rank's slice.
4. Split and return corresponding `mc2_mask`.
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
if self.tp_size > 1:
# Also slice mc2_mask
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
mc2_mask = split_mc2_mask[self.tp_rank]
if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape
target_pad_length = forward_context.padded_num_tokens
pad_size = target_pad_length - self.num_tokens
# Pad if necessary (unless shared expert DP is enabled)
if pad_size > 0 and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
# Slice across TP ranks
if self.tp_size > 1 and not self.enable_shared_expert_dp:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
self.split_hidden_states = split_hidden_states # Save for finalize
return hidden_states, router_logits, mc2_mask
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
2. Unpad to original token count if padding was applied.
3. Return tensor with shape [original_num_tokens, hidden_size].
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
# All-gather across TP group
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
# Unpad if necessary
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
There are two sets of prepare and finalize:
1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
we gather inputs across DP ranks before MoE, scatter outputs after.
The communication and calculation process is as follows (AG, AR and RS
are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):
Attn → TP AR → DP AG → MoE → DP RS → TP AR
2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
the above process becomes:
TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS
This strategy further combines TP AG + DP AG into EP All-Gather and TP RS + DP RS
into EP Reduce-Scatter to improve communication performance. The optimized process is as follows:
TP AG → Attn → TP RS → EP AG → MoE → EP RS
"""
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits)
return self._prepare_with_dp_group(hidden_states, router_logits,
enable_shared_expert_dp,
replace_allreduce)
def _prepare_with_ep_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
return hidden_states, router_logits, None
def _prepare_with_dp_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
2. Pad local tensors to that size.
3. All-gather across DP group to form global input tensor.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
forward_context = get_forward_context()
max_tokens_across_dp = forward_context.max_tokens_across_dp
self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
# All-gather across DP group
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
Reduce Scatter hidden states.
Returns:
Tensor with shape [local_num_tokens, hidden_size]
"""
if enable_sp():
return self._finalize_with_ep_group(hidden_states)
return self._finalize_with_dp_group(hidden_states, reduce_results)
def _finalize_with_ep_group(self,
hidden_states: torch.Tensor) -> torch.Tensor:
"""
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
1. Reduce_results is False usually happens when models have shared experts and need to
allreduce hidden states after results of shared experts and routed experts are added in FusedMoe.
We do reduce scatter for hidden states here, then skip allreudce in FusedMoe and add it to the
result of shared experts.
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
here, then skip allreudce in FusedMoe.
"""
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
hidden_states, True)
return hidden_states
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
2. Slice to original local token count.
3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
Returns:
Tensor with shape [original_local_num_tokens, hidden_size]
"""
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
hidden_states = hidden_states[:self.num_tokens]
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using Naive Multicast (point-to-point broadcast).
Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries.
"""
def _naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
"""
Naive multicast implementation:
1. Create global buffer sized by total tokens across DP.
2. Current rank copies its slice into its designated buffer region.
3. Each rank broadcasts its slice to all others via P2P.
Args:
x (torch.Tensor): Local tensor [local_tokens, hidden_size]
cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank
Returns:
torch.Tensor: Global tensor [total_tokens, hidden_size]
"""
assert len(x.shape) == 2, "Input must be 2D [tokens, features]"
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
# Copy local slice into buffer
start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.moe_config.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
buffer[start:end, :].copy_(x)
# Broadcast each slice to all ranks
for idx in range(self.moe_config.dp_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
get_dp_group().broadcast(buffer[start:end, :], idx)
return buffer
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Fetch cumulative token boundaries from forward context.
2. Multicast hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
self.cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_sp(1)
hidden_states = self._naive_multicast(hidden_states,
self.cu_tokens_across_dp_cpu)
router_logits = self._naive_multicast(router_logits,
self.cu_tokens_across_dp_cpu)
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert:
- All-reduce across DP
- Slice to current rank's token range using cu_tokens_across_dp_cpu
2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
Returns:
Tensor with shape [local_num_tokens, hidden_size]
"""
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[
self.moe_config.dp_rank - 1]
end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
hidden_states = get_dp_group().all_reduce(
hidden_states) # Sum across DP
hidden_states = hidden_states[start:end, :]
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states