Reapply "[MoE] [Refactor] Remove manual memory cleanup (#3365)" (#3483) (#3512)

### What this PR does / why we need it?
1. Replace manual memory cleanup with passing parameter.
2. FusedMoEPrepareAndFinalizeWithMC2 inherits All2All avoid duplicated
code.
3. Fix MC2 bug introduced in
https://github.com/vllm-project/vllm-ascend/pull/3365
4. Unify aclgraph & eager in W8A8_dynamic.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-10-22 11:41:30 +08:00
committed by GitHub
parent 6ef62cb427
commit 2f1b9a7a64
13 changed files with 608 additions and 522 deletions

View File

@@ -144,7 +144,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
expert_map=expert_map,
shared_experts=shared_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb)
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
class AscendFusedMoE(FusedMoE):
@@ -305,7 +306,7 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance = forward_context.in_profile_run
forward_context = get_forward_context()
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled,
@@ -333,7 +334,8 @@ class AscendFusedMoE(FusedMoE):
shared_experts=None,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num)
global_redundant_expert_num=self.global_redundant_expert_num,
mc2_mask=mc2_mask)
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
@@ -344,7 +346,8 @@ class AscendFusedMoE(FusedMoE):
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
reduce_results=self.reduce_results)
reduce_results=self.reduce_results,
context_metadata=context_metadata)
return final_hidden_states