diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index c57146b..f028756 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -264,6 +264,10 @@ class AscendFusedMoE(FusedMoE): quantized_x_for_share, dynamic_scale_for_share = None, None forward_context = get_forward_context() + + # Load balancing for token distribution among experts in dummy_run + # TODO: The community only considers load balancing when DP > 1. + # This approach may overlook some extreme scenarios. enable_force_load_balance = forward_context.in_profile_run forward_context = get_forward_context() diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index ab4987f..6701f70 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -217,6 +217,12 @@ class AscendW8A8DynamicFusedMoEMethod: e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) + # this is a naive implementation for expert load balancing so as + # to avoid accumulating too many tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + if self.use_aclgraph: moe_comm_method = get_forward_context().moe_comm_method return moe_comm_method.fused_experts( @@ -232,12 +238,6 @@ class AscendW8A8DynamicFusedMoEMethod: expert_map=expert_map, dynamic_eplb=self.dynamic_eplb) - # this is a naive implementation for experts load balance so as - # to avoid accumulating too much tokens on a single rank. - # currently it is only activated when doing profile runs. - if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - topk_weights = topk_weights.to(x.dtype) moe_comm_method = get_forward_context().moe_comm_method