diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 29ef1fb..eeb8ec3 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -22,6 +22,7 @@
 from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
                                        select_experts)
 from vllm_ascend.utils import is_310p
@@ -33,7 +34,15 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
     original_unquantized_fused_moe_init_func(self, *args, **kwargs)
     vllm_config = get_current_vllm_config()
     self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-    self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager
+
+    ascend_config = get_ascend_config()
+
+    if ascend_config.torchair_graph_config.enabled:
+        self.use_aclgraph = False
+    else:
+        self.use_aclgraph = (vllm_config.compilation_config.level
+                             == CompilationLevel.PIECEWISE
+                             and not vllm_config.model_config.enforce_eager)
 
 
 def forward_oot(
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index b2b1ab9..d4784d4 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -1105,7 +1105,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
+        if enable_force_load_balance and not self.use_aclgraph:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         fused_moe_state = get_forward_context().fused_moe_state
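
For reviewers, here is a minimal runnable sketch of the `use_aclgraph` decision the first hunk introduces. The `SimpleNamespace` config objects and the `resolve_use_aclgraph` helper are illustrative stand-ins, not the real vllm/vllm-ascend classes; only the attribute names (`torchair_graph_config.enabled`, `compilation_config.level`, `model_config.enforce_eager`) mirror the diff. The second hunk follows the same flag: when `use_aclgraph` is true, the forced load-balance randomization of `topk_ids` is skipped, presumably so the `torch.randint_like` call is not baked into a captured graph.

```python
# Sketch of the use_aclgraph decision from the first hunk.
# SimpleNamespace stands in for the real config objects; only the
# attribute names mirror the diff.
from enum import Enum
from types import SimpleNamespace


class CompilationLevel(Enum):
    """Stand-in for vllm.config.CompilationLevel (illustrative values)."""
    NO_COMPILATION = 0
    PIECEWISE = 3


def resolve_use_aclgraph(vllm_config, ascend_config) -> bool:
    """Hypothetical helper mirroring unquantized_fused_moe_init_func."""
    # Torchair graph mode takes precedence: never use aclgraph with it.
    if ascend_config.torchair_graph_config.enabled:
        return False
    # Otherwise aclgraph requires piecewise compilation without enforce_eager.
    return (vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)


if __name__ == "__main__":
    vllm_config = SimpleNamespace(
        compilation_config=SimpleNamespace(level=CompilationLevel.PIECEWISE),
        model_config=SimpleNamespace(enforce_eager=False),
    )
    torchair_on = SimpleNamespace(
        torchair_graph_config=SimpleNamespace(enabled=True))
    torchair_off = SimpleNamespace(
        torchair_graph_config=SimpleNamespace(enabled=False))

    # Torchair graph mode disables aclgraph regardless of compilation level.
    assert resolve_use_aclgraph(vllm_config, torchair_on) is False
    # Without torchair, piecewise compilation and no enforce_eager enable it.
    assert resolve_use_aclgraph(vllm_config, torchair_off) is True
```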