Reapply "[MoE] [Refactor] Remove manual memory cleanup (#3365)" (#3483) (#3512)

### What this PR does / why we need it?
1. Replace manual memory cleanup with parameter passing (see the sketch below).
2. FusedMoEPrepareAndFinalizeWithMC2 now inherits from All2All to avoid duplicated
code.
3. Fix the MC2 bug introduced in
https://github.com/vllm-project/vllm-ascend/pull/3365
4. Unify the aclgraph & eager paths in W8A8_dynamic.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & unit tests
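
To make change 1 concrete: the new contract is that `prepare()` returns a `context_metadata` dict (carrying e.g. `split_hidden_states`) which the caller later passes back into `finalize()`, replacing the old pattern of stashing the splits on `self` and manually `del`-ing them. Below is a minimal, single-process sketch of that hand-off; the class and padding logic are simplified stand-ins (no real `dist.all_gather`), and only the `prepare`/`finalize`/`context_metadata` names mirror the diff.

```python
from typing import Optional

import torch


class _All2AllSketch:
    """Illustrative stand-in for the prepare/finalize metadata hand-off; not the real class."""

    def __init__(self, tp_size: int, tp_rank: int):
        self.tp_size = tp_size
        self.tp_rank = tp_rank

    def prepare(self, hidden_states: torch.Tensor):
        num_tokens = hidden_states.shape[0]
        # Pad to the next multiple of tp_size so the token dim splits evenly.
        pad = (-num_tokens) % self.tp_size
        if pad:
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
        splits = torch.tensor_split(hidden_states, self.tp_size, dim=0)
        # Instead of `self.split_hidden_states = splits` plus a manual `del` in
        # finalize, hand the state back to the caller explicitly.
        context_metadata = {"split_hidden_states": splits, "num_tokens": num_tokens}
        return splits[self.tp_rank], context_metadata

    def finalize(self, hidden_states: torch.Tensor,
                 context_metadata: Optional[dict] = None) -> torch.Tensor:
        assert context_metadata is not None
        # Single-process stand-in for dist.all_gather: reuse the saved splits.
        splits = list(context_metadata["split_hidden_states"])
        splits[self.tp_rank] = hidden_states
        out = torch.cat(splits, dim=0)
        return out[:context_metadata["num_tokens"]]


# The metadata only lives for the duration of one forward call.
layer = _All2AllSketch(tp_size=2, tp_rank=0)
x = torch.randn(5, 8)
shard, meta = layer.prepare(x)
y = layer.finalize(shard, context_metadata=meta)
assert y.shape == x.shape
```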

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Author: weichen
Date: 2025-10-22 11:41:30 +08:00
Committed by: GitHub
Parent: 6ef62cb427
Commit: 2f1b9a7a64
13 changed files with 608 additions and 522 deletions


@@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.distributed as dist
@@ -49,12 +50,15 @@ class FusedMoEPrepareAndFinalize(ABC):
is_deepseek_v3_r1)
@abstractmethod
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
@@ -74,11 +78,14 @@ class FusedMoEPrepareAndFinalize(ABC):
- processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops)
- optional context metadata (e.g., saved split_hidden_states for finalization)
"""
raise NotImplementedError("Prepare not implemented.")
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks
@@ -96,9 +103,102 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError("Finalize function not implemented.")
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using MC2 (Memory-Centric Communication).
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Used when num_tokens exceeds MC2's limit (512 tokens per rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states}
return hidden_states, router_logits, None, context_metadata
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
assert context_metadata is not None
split_hidden_states = context_metadata["split_hidden_states"]
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
"""
MoE communication strategy using MC2, which is based on All2All; hence it inherits from
All2All and shares the same finalize method.
Designed for Ascend or environments requiring explicit padding and slicing control.
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
"""
@@ -116,12 +216,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
@@ -132,10 +235,11 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
if self.tp_size > 1:
@@ -165,124 +269,10 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
self.split_hidden_states = split_hidden_states # Save for finalize
return hidden_states, router_logits, mc2_mask
context_metadata = {"split_hidden_states": split_hidden_states}
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
2. Unpad to original token count if padding was applied.
3. Return tensor with shape [original_num_tokens, hidden_size].
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
# All-gather across TP group
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
# Unpad if necessary
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
return hidden_states, router_logits, mc2_mask, context_metadata
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
@@ -307,12 +297,15 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
TP AG → Attn → TP RS → EP AG → MoE → EP RS
"""
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
@@ -331,21 +324,24 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
return hidden_states, router_logits, None
return hidden_states, router_logits, None, None
def _prepare_with_dp_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
@@ -353,7 +349,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
3. All-gather across DP group to form global input tensor.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
Tuple of (global_hidden_states, global_router_logits, None, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
@@ -377,11 +373,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
else:
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
return hidden_states, router_logits, None, None
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
Reduce Scatter hidden states.
@@ -472,19 +469,22 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
get_dp_group().broadcast(buffer[start:end, :], idx)
return buffer
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Fetch cumulative token boundaries from forward context.
2. Multicast hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
Tuple of (global_hidden_states, global_router_logits, None, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -499,10 +499,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
router_logits = self._naive_multicast(
router_logits, self.cu_tokens_across_dp_cpu)
return hidden_states, router_logits, None
return hidden_states, router_logits, None, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert: