[MoE] [Refactor] Remove manual memory cleanup (#3365)

### What this PR does / why we need it?
1. Replace manual memory cleanup with explicit parameter passing (see the sketch below).
2. Make FusedMoEPrepareAndFinalizeWithMC2 inherit from All2All to avoid duplicated code.
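
The first point replaces dispatcher-internal state, which previously had to be cleared by hand, with metadata that the caller receives from `token_dispatch` and hands back to `token_combine`. Below is a minimal, self-contained sketch of that pattern; the class names, the `num_tokens` field, and the slicing logic are illustrative only and are not the actual vllm-ascend implementation.

```python
import torch


class OldStyleDispatcher:
    """Illustrative 'before' shape: caches combine metadata on the instance."""

    def token_dispatch(self, hidden_states: torch.Tensor) -> dict:
        # Stand-in metadata; in practice this would hold sort indices, masks, etc.
        self._cached_metadata = {"num_tokens": hidden_states.shape[0]}
        return {"hidden_states": hidden_states}

    def token_combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = hidden_states[: self._cached_metadata["num_tokens"]]
        self._cached_metadata = None  # manual memory cleanup after every step
        return out


class NewStyleDispatcher:
    """Illustrative 'after' shape: metadata is returned to and passed back by the caller."""

    def token_dispatch(self, hidden_states: torch.Tensor) -> dict:
        return {
            "hidden_states": hidden_states,
            "context_metadata": {"num_tokens": hidden_states.shape[0]},
        }

    def token_combine(self, hidden_states: torch.Tensor,
                      context_metadata: dict) -> torch.Tensor:
        # No instance state to clear; the metadata lives on the caller's side.
        return hidden_states[: context_metadata["num_tokens"]]


x = torch.randn(4, 8)
dispatcher = NewStyleDispatcher()
out = dispatcher.token_dispatch(x)
combined = dispatcher.token_combine(out["hidden_states"], out["context_metadata"])
assert combined.shape == x.shape
```

The updated tests below exercise exactly this calling convention: `context_metadata` is read from `dispatch_output` and forwarded to `token_combine`.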

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e and unit tests

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Author: weichen
Date: 2025-10-15 12:36:24 +08:00 (committed by GitHub)
Parent: 4e720936d8
Commit: 4f937f561d
8 changed files with 562 additions and 492 deletions

@@ -137,6 +137,7 @@ def test_token_dispatcher_with_all_gather(
     sorted_hidden_states = dispatch_output["hidden_states"]
     group_list = dispatch_output["group_list"]
     group_list_type = dispatch_output.get("group_list_type", 1)
+    context_metadata = dispatch_output["context_metadata"]
 
     expert_output = apply_mlp(hidden_states=sorted_hidden_states,
                               w1=w1_local,
@@ -144,8 +145,10 @@ def test_token_dispatcher_with_all_gather(
                               group_list=group_list,
                               group_list_type=group_list_type)
 
-    combined_output = dispatcher.token_combine(hidden_states=expert_output,
-                                               bias=None)
+    combined_output = dispatcher.token_combine(
+        hidden_states=expert_output,
+        context_metadata=context_metadata,
+        bias=None)
 
     torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
                              expert_map)
@@ -215,6 +218,7 @@ def test_token_dispatcher_with_all_gather_quant(
     group_list = dispatch_output["group_list"]
     group_list_type = dispatch_output.get("group_list_type", 1)
     dynamic_scale = dispatch_output["dynamic_scale"]
+    context_metadata = dispatch_output["context_metadata"]
 
     expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
                                       w1=w1,
@@ -225,8 +229,10 @@ def test_token_dispatcher_with_all_gather_quant(
                                       group_list_type=group_list_type,
                                       dynamic_scale=dynamic_scale,
                                       with_quant=True)
-    combined_output = dispatcher.token_combine(hidden_states=expert_output,
-                                               bias=None)
+    combined_output = dispatcher.token_combine(
+        hidden_states=expert_output,
+        context_metadata=context_metadata,
+        bias=None)
     assert combined_output.shape == (m, k)
     gc.collect()
     torch.npu.empty_cache()
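
For the second point, the MC2 variant now reuses the All2All prepare/finalize flow by subclassing it rather than duplicating it. The sketch below only illustrates that inheritance structure; the base-class name, method bodies, and the `prepare`/`finalize` signatures are assumptions for illustration and are not taken from this diff.

```python
import torch


class FusedMoEPrepareAndFinalizeWithAll2All:
    """Illustrative base class holding the shared prepare/finalize logic."""

    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Common pre-processing shared by both communication backends (assumed).
        return hidden_states.contiguous()

    def finalize(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states


class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
    """Overrides only the MC2-specific step; everything else is inherited."""

    def finalize(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Any MC2-specific post-processing would go here; the shared path is reused.
        return super().finalize(hidden_states)


moe = FusedMoEPrepareAndFinalizeWithMC2()
h = torch.randn(2, 4)
assert torch.equal(moe.finalize(moe.prepare(h)), h)
```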