[Fix] Adjust use_aclgraph logic (#2156)

### What this PR does / why we need it?
Updates the FusedMoE method to determine whether to use ACL Graph based
on the `torchair_graph_config`

This is equivalent to #2154 on v0.9.1-dev.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
None needed.

- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-04 15:23:20 +08:00
committed by GitHub
parent 688350a3bb
commit a9480d5f0a
2 changed files with 11 additions and 2 deletions

View File

@@ -22,6 +22,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
select_experts)
from vllm_ascend.utils import is_310p
@@ -33,7 +34,15 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
vllm_config = get_current_vllm_config()
self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
ascend_config = get_ascend_config()
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
def forward_oot(

View File

@@ -1105,7 +1105,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state