diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 9ba604f..cc2d307 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -98,7 +98,7 @@ class TestTokenDispatcherWithMC2(TestBase): self.row_idx, expert_map) mock_dispatch.assert_called_once() self.assertEqual(output["group_list_type"], - 1) # group_list_type == 1 + 0) # group_list_type == 0 def test_token_dispatch_with_shared_experts_and_quant(self): self.shared_experts = MagicMock() diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py index 3cc8b95..b74f945 100644 --- a/vllm_ascend/ops/moe/moe_mlp.py +++ b/vllm_ascend/ops/moe/moe_mlp.py @@ -79,8 +79,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor, is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and is_mc2: - if w1_scale.dtype != torch.float32: - w1_scale = w1_scale.to(torch.float32) if fusion: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( @@ -90,6 +88,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, weight_scale=w1_scale, x_scale=pertoken_scale) else: + if w1_scale.dtype != torch.float32: + w1_scale = w1_scale.to(torch.float32) # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index b6f908e..90c84d5 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -133,6 +133,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": 0, + "expert_token_nums_type": 0, } stage1_kwargs = { @@ -204,7 +205,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): if shared_experts is not None: shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) self.shared_act = shared_experts.act_fn(shared_gate_up) - group_list_type = 1 + group_list_type = 0 return { "group_list_type": group_list_type, "hidden_states": expand_x,