diff --git a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
index 6a204dc0..48777894 100644
--- a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
+++ b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
@@ -47,8 +47,8 @@ def test_generate_task_and_state_flow(mock_adaptor):
     loader_obj.state = loader.ExpertWeightUpdateState.WAITING
     loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
 
-    assert loader_obj.comm_op_list is None
-    assert loader_obj.state == loader.ExpertWeightUpdateState.WAITING
+    assert not loader_obj.comm_op_list
+    assert loader_obj.state == loader.ExpertWeightUpdateState.READY
 
 
 def test_asyn_transfer_and_update(mock_adaptor):
diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index e40f6708..5ff39c4e 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -26,7 +26,7 @@ class TestMoECommMethod(TestBase):
         self.moe_config.tp_size = 1
         self.moe_config.ep_size = 1
         self.moe_config.dp_group = MagicMock()
-        self.moe_config.num_global_redundant_experts = 0
+        self.moe_config.global_redundant_expert_num = 0
 
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index 027815ba..d27da5cf 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -143,7 +143,7 @@ class TestTokenDispatcherWithMC2(TestBase):
         self.dispatcher.need_extra_args = True
         self.dispatcher.enable_dispatch_v2 = True
-
+        self.dispatcher.moe_expert_num = len(expert_map)
         kwargs = self.dispatcher.get_combine_mc_kwargs(
             hidden_states, context_metadata)
         self.assertIn("tp_send_counts", kwargs)
 
diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
index 5c676cdd..ce1c3d73 100644
--- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
+++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
@@ -50,10 +50,6 @@ class D2DExpertWeightLoader:
             )
             return
 
-        # If neither send nor receive task is needed for this layer on this rank, return
-        if not (expert_send_info or expert_recv_info):
-            return
-
         self.updated_expert_map = updated_expert_map
 
         self.layer_id = layer_id
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index efc709a3..f9534515 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -210,7 +210,7 @@ class AscendFusedMoE(FusedMoE):
 
         self.moe_config.num_experts = self.global_num_experts
         self.moe_config.num_local_experts = self.local_num_experts
-        self.moe_config.original_num_experts = num_experts
+        self.moe_config.global_redundant_expert_num = self.global_redundant_expert_num
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 06fd2fe4..07488f2d 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -114,7 +114,6 @@ class MoECommMethod(ABC):
         dynamic_scale_for_share: Optional[Any] = None,
         # For load balance
         log2phy: torch.Tensor = None,
-        global_redundant_expert_num: int = 0,
         need_trans: bool = False,
         dynamic_eplb: bool = False,
         mc2_mask: torch.Tensor = None,
@@ -133,7 +132,8 @@ class MoECommMethod(ABC):
             topk_ids=topk_ids,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
+            global_redundant_expert_num=self.moe_config.
+            global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
@@ -290,7 +290,6 @@ class FusedMC2CommImpl(MoECommMethod):
         dynamic_scale_for_share: Optional[Any] = None,
         # For load balance
         log2phy: torch.Tensor = None,
-        global_redundant_expert_num: int = 0,
         need_trans: bool = False,
         dynamic_eplb: bool = False,
         mc2_mask: torch.Tensor = None,
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 0513307a..b40a0583 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -152,18 +152,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         mc2_mask: torch.Tensor,
         global_redundant_expert_num: int = 0,
     ):
-        if self.with_quant:
-            quant_mode = 2
-            moe_expert_num = len(expert_map)
-        else:
-            quant_mode = 0
-            moe_expert_num = len(expert_map)
+        quant_mode = 2 if self.with_quant else 0
+        self.moe_expert_num = len(expert_map) + global_redundant_expert_num
         kwargs_mc2 = {
             "x": hidden_states,
             "expert_ids": topk_ids,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.moe_expert_num,
             "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
         }
@@ -253,7 +249,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         expand_scales = context_metadata["expand_scales"]
 
         assert expert_map is not None
-        moe_expert_num = len(expert_map)
 
         kwargs_mc2 = {
             "expand_x": hidden_states,
@@ -261,7 +256,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
             "expert_scales": topk_weights.to(torch.float32),
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.moe_expert_num,
             "global_bs": self.global_bs,
         }
 
@@ -347,7 +342,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             hidden_states = hidden_states * \
                 topk_weights.to(hidden_states.dtype)
         if expert_map is not None:
-            global_num_experts = len(expert_map)
+            global_num_experts = len(expert_map) + global_redundant_expert_num
             mask = (expert_map[topk_ids] != -1)
             topk_weights = topk_weights * mask
             first_expert_idx = get_ep_group(
diff --git a/vllm_ascend/quantization/w4a16.py b/vllm_ascend/quantization/w4a16.py
index 4fcc3380..c6eb379d 100644
--- a/vllm_ascend/quantization/w4a16.py
+++ b/vllm_ascend/quantization/w4a16.py
@@ -243,7 +243,6 @@ class AscendW4A16FusedMoEMethod:
             use_int4_w4a16=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index 3222f2ea..167a42fc 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -391,7 +391,6 @@ class AscendW4A8DynamicFusedMoEMethod:
             use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index bebd807b..cba58850 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -279,7 +279,6 @@ class AscendW8A8DynamicFusedMoEMethod:
             use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
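
Note: a minimal, self-contained sketch of the bookkeeping this patch introduces. The names below (SimpleDispatcher, dispatch, combine) are hypothetical and not part of the repository; the point is only the arithmetic: with expert redundancy enabled, the physical expert count handed to MC2 is len(expert_map) plus global_redundant_expert_num, cached once at dispatch time and reused at combine time instead of being recomputed from expert_map alone.

# Hypothetical sketch; SimpleDispatcher is illustrative, not a class from this patch.
class SimpleDispatcher:

    def __init__(self, global_redundant_expert_num: int = 0):
        # Mirrors moe_config.global_redundant_expert_num in the patch.
        self.global_redundant_expert_num = global_redundant_expert_num
        self.moe_expert_num = 0

    def dispatch(self, expert_map: list) -> int:
        # Physical expert count = logical experts + redundant replicas,
        # cached once at dispatch time (cf. TokenDispatcherWithMC2).
        self.moe_expert_num = len(expert_map) + self.global_redundant_expert_num
        return self.moe_expert_num

    def combine(self) -> int:
        # Combine reuses the cached value; recomputing len(expert_map) here
        # would undercount by the redundant experts.
        return self.moe_expert_num


d = SimpleDispatcher(global_redundant_expert_num=2)
assert d.dispatch(expert_map=[0, 1, 2, 3]) == 6  # 4 logical + 2 redundant
assert d.combine() == 6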