[Bugfix][MoE] enable force_load_balance in aclgraph (#4366)

### What this PR does / why we need it?
Temporarily fix the oom issue, will update to vllm's plan later.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e&ut

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-11-24 20:33:56 +08:00
committed by GitHub
parent ae068a3342
commit 5b1a7514eb

View File

@@ -19,7 +19,7 @@ from typing import Any, Callable, Optional
import torch
import torch_npu
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
@@ -54,21 +54,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
super().__init__(moe=moe)
# NOTE: Currently, this self.use_aclgraph is only used in
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
# Once torch.randint_like is supported or removed, this flag can be removed.
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
self.dynamic_eplb = get_ascend_config().dynamic_eplb
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE and
not vllm_config.model_config.enforce_eager)
self.transpose = True
def process_weights_after_loading(self, layer):
@@ -137,7 +123,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
moe_comm_method = get_forward_context().moe_comm_method