[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)
### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,16 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
|
||||
in distributed environments. Subclasses implement specific communication strategies
|
||||
(e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing,
|
||||
broadcasting, and reduction across TP/DP/EP groups.
|
||||
|
||||
Attributes:
|
||||
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
@@ -40,22 +50,65 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare tensors before MoE computation. May involve:
|
||||
- Padding to align communication boundaries
|
||||
- Slicing across tensor-parallel ranks
|
||||
- Broadcasting across data-parallel ranks
|
||||
- Recomputing router logits if needed
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
rm_router_logits (bool): Discard input router_logits and recompute via gate
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- processed hidden_states (may be padded/sliced/broadcasted)
|
||||
- processed router_logits (may be recomputed or broadcasted)
|
||||
- optional communication mask (e.g., mc2_mask for sparse ops)
|
||||
"""
|
||||
raise NotImplementedError("Prepare not implemented.")
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
raise NotImplementedError("Combine function not implemented.")
|
||||
"""
|
||||
Finalize MoE output. May involve:
|
||||
- Gathering sliced tensors across TP ranks
|
||||
- Reducing or scattering across DP ranks
|
||||
- Unpadding to original token count
|
||||
- Applying all-reduce across TP/EP if requested
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced
|
||||
reduce_results (bool): Whether to apply all-reduce across TP/EP groups
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Final output with shape [original_num_tokens, hidden_size]
|
||||
"""
|
||||
raise NotImplementedError("Finalize function not implemented.")
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using MC2 (Memory-Centric Communication).
|
||||
Designed for Ascend or environments requiring explicit padding and slicing control.
|
||||
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
|
||||
# tp_size and tp_rank.
|
||||
"""
|
||||
Restore original TP configuration.
|
||||
vLLM flattens TP and DP into a single dimension; this method recovers
|
||||
the true TP world size and rank for correct tensor slicing.
|
||||
"""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
@@ -66,9 +119,17 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""The target_pad_length is calculated in forward_context, here we pad the
|
||||
hidden states and router logits. And if TP size > 1, we also need to split
|
||||
the tensors accordingly.
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||
2. Pad `hidden_states` and `router_logits` to target length if needed.
|
||||
3. If TP > 1, split tensors along token dimension and select current TP rank's slice.
|
||||
4. Split and return corresponding `mc2_mask`.
|
||||
|
||||
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
@@ -80,12 +141,14 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
pad_size = target_pad_length - self.num_tokens
|
||||
|
||||
# Pad if necessary (unless shared expert DP is enabled)
|
||||
if pad_size > 0 and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# Slice across TP ranks
|
||||
if self.tp_size > 1:
|
||||
if not self.enable_shared_expert_dp:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
@@ -96,8 +159,9 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
dim=0)
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
self.split_hidden_states = split_hidden_states
|
||||
self.split_hidden_states = split_hidden_states # Save for finalize
|
||||
|
||||
# Also slice mc2_mask
|
||||
split_mc2_mask = torch.tensor_split(mc2_mask,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
@@ -107,16 +171,22 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""If TP size > 1, all-gather the hidden states to get the final output.
|
||||
|
||||
Also, unpad the hidden states if needed.
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
|
||||
2. Unpad to original token count if padding was applied.
|
||||
3. Return tensor with shape [original_num_tokens, hidden_size].
|
||||
|
||||
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
# All-gather across TP group
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
# Unpad if necessary
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
@@ -124,14 +194,18 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-to-All style slicing.
|
||||
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
|
||||
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
|
||||
# tp_size and tp_rank.
|
||||
"""Restore original TP configuration (same as MC2)."""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
@@ -142,12 +216,23 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
2. If TP > 1, split along token dim and select current TP rank's slice.
|
||||
3. Save splits for later all-gather in finalize.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if not (self.replace_allreduce or self.enable_shared_expert_dp):
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
@@ -171,9 +256,13 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""If TP size > 1, all-gather the hidden states to get the final output.
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices to reconstruct full tensor.
|
||||
2. Unpad to original token count.
|
||||
3. Return [original_num_tokens, hidden_size] tensor.
|
||||
|
||||
Also, unpad the hidden states if needed.
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
@@ -188,6 +277,11 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-Gather + Reduce-Scatter.
|
||||
Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
|
||||
Uses `max_tokens_across_dp` from forward_context for padding alignment.
|
||||
"""
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -196,8 +290,16 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""When DP size > 1, pad the hidden states and router logits for communication."""
|
||||
self.rm_router_logits = rm_router_logits
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
2. Pad local tensors to that size.
|
||||
3. All-gather across DP group to form global input tensor.
|
||||
4. Optionally recompute router_logits using gate if `rm_router_logits=True`.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if self.moe_config.dp_size > 1:
|
||||
@@ -209,14 +311,15 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
if not self.rm_router_logits:
|
||||
if not rm_router_logits:
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
if rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states) # Recompute globally
|
||||
else:
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
@@ -225,9 +328,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""When DP size > 1, reduce-scatter the hidden states to get the final output.
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
|
||||
2. Slice to original local token count.
|
||||
3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
|
||||
|
||||
When TP size > 1, all-reduce the hidden states to get the final output.
|
||||
Returns:
|
||||
Tensor with shape [original_local_num_tokens, hidden_size]
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
|
||||
@@ -238,3 +346,101 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using Naive Multicast (point-to-point broadcast).
|
||||
Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
|
||||
Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries.
|
||||
"""
|
||||
|
||||
def _naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
"""
|
||||
Naive multicast implementation:
|
||||
1. Create global buffer sized by total tokens across DP.
|
||||
2. Current rank copies its slice into its designated buffer region.
|
||||
3. Each rank broadcasts its slice to all others via P2P.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Local tensor [local_tokens, hidden_size]
|
||||
cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Global tensor [total_tokens, hidden_size]
|
||||
"""
|
||||
assert len(x.shape) == 2, "Input must be 2D [tokens, features]"
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
# Copy local slice into buffer
|
||||
start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.moe_config.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
|
||||
# Broadcast each slice to all ranks
|
||||
for idx in range(self.moe_config.dp_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch cumulative token boundaries from forward context.
|
||||
2. Multicast hidden_states and router_logits to form global tensors.
|
||||
3. Optionally recompute router_logits globally if `rm_router_logits=True`.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if self.moe_config.dp_size > 1:
|
||||
self.cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
hidden_states = self._naive_multicast(hidden_states,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
if rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self._naive_multicast(
|
||||
router_logits, self.cu_tokens_across_dp_cpu)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert:
|
||||
- All-reduce across DP
|
||||
- Slice to current rank's token range using cu_tokens_across_dp_cpu
|
||||
2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
|
||||
|
||||
Returns:
|
||||
Tensor with shape [local_num_tokens, hidden_size]
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[
|
||||
self.moe_config.dp_rank - 1]
|
||||
end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
|
||||
hidden_states = get_dp_group().all_reduce(
|
||||
hidden_states) # Sum across DP
|
||||
hidden_states = hidden_states[start:end, :]
|
||||
|
||||
if reduce_results and (self.moe_config.tp_size > 1
|
||||
or self.moe_config.ep_size > 1):
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user