From 5b1a7514eb101c69e4925c22d83d02f3219f2d20 Mon Sep 17 00:00:00 2001
From: weichen
Date: Mon, 24 Nov 2025 20:33:56 +0800
Subject: [PATCH] [Bugfix][MoE] enable force_load_balance in aclgraph (#4366)

### What this PR does / why we need it?
Temporarily fix the OOM issue; this will be updated to follow vLLM's plan later.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

Signed-off-by: Pr0Wh1teGivee
---
 vllm_ascend/ops/fused_moe/fused_moe.py | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 4788c87d..c78764c4 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -19,7 +19,7 @@ from typing import Any, Callable, Optional
 
 import torch
 import torch_npu
-from vllm.config import CompilationMode, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
                               tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
@@ -54,21 +54,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: FusedMoEConfig = None):
         super().__init__(moe=moe)
-
-        # NOTE: Currently, this self.use_aclgraph is only used in
-        # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
-        # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
-        # Once torch.randint_like is supported or removed, this flag can be removed.
-        vllm_config = get_current_vllm_config()
-        ascend_config = get_ascend_config()
         self.dynamic_eplb = get_ascend_config().dynamic_eplb
-        if ascend_config.torchair_graph_config.enabled:
-            self.use_aclgraph = False
-        else:
-            self.use_aclgraph = (vllm_config.compilation_config.mode
-                                 == CompilationMode.VLLM_COMPILE and
-                                 not vllm_config.model_config.enforce_eager)
-
         self.transpose = True
 
     def process_weights_after_loading(self, layer):
@@ -137,7 +123,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance and not self.use_aclgraph:
+        if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         moe_comm_method = get_forward_context().moe_comm_method
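
For reviewers, here is a minimal standalone sketch of what the force-load-balance path above does once the `not self.use_aclgraph` guard is dropped: it overwrites the router's top-k expert ids with uniformly random ones via `torch.randint_like`, so profile-run tokens spread roughly evenly across experts (and hence EP ranks) instead of piling onto one rank and triggering the OOM. The `num_tokens`, `top_k`, and `global_num_experts` values below are illustrative assumptions, not taken from the patch.

```python
import torch

# Illustrative shapes/values (assumptions, not from the patch).
num_tokens = 8            # tokens in a dummy profile batch
top_k = 2                 # experts selected per token
global_num_experts = 64   # total experts across all EP ranks

# A worst-case routing result: every token picks expert 0, so a single
# rank would receive all tokens during the profile run.
topk_ids = torch.zeros(num_tokens, top_k, dtype=torch.int32)

# The forced-balance path from the patch: redraw expert ids uniformly at
# random, keeping the shape/dtype/device of the original routing result.
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

# Token counts per expert are now roughly uniform, which bounds the peak
# per-rank memory that the profile run observes.
print(torch.bincount(topk_ids.flatten().long(), minlength=global_num_experts))
```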