diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py
index 8faa3bb2..d1981b2a 100644
--- a/tests/ut/ops/test_fused_moe.py
+++ b/tests/ut/ops/test_fused_moe.py
@@ -294,13 +294,13 @@ class TestCumsumGroupList(TestBase):
     def test_cumsum_group_list_with_type_0(self):
         group_list = self.experts.cumsum(dim=0)
         group_list_type = 0
-        result = cumsum_group_list(group_list, group_list_type)
+        result = cumsum_group_list(group_list, group_list_type, 0)
         self.assertTrue(torch.equal(result, self.group_list))
 
     def test_cumsum_group_list_with_type_1(self):
         group_list = self.experts
         group_list_type = 1
-        result = cumsum_group_list(group_list, group_list_type)
+        result = cumsum_group_list(group_list, group_list_type, 0)
         self.assertTrue(torch.equal(result, self.group_list))
 
     def test_cumsum_group_list_with_type_2(self):
@@ -313,6 +313,7 @@ class TestCumsumGroupList(TestBase):
         group_list_type = 2
         result = cumsum_group_list(group_list,
                                    group_list_type,
+                                   0,
                                    active_num=self.active_num,
                                    expert_num=self.expert_num)
         self.assertTrue(torch.equal(result, self.group_list))
@@ -592,4 +593,4 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertTrue(mock_forward_context.with_quant)
 
         self.assertEqual(result.shape, hidden_states_shape)
-        self.assertEqual(result.dtype, torch.bfloat16)
+        self.assertEqual(result.dtype, torch.bfloat16)
\ No newline at end of file
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index 3fc12644..c2ba8037 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -31,31 +31,39 @@ def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
 
 
 def cumsum_group_list(group_list: torch.Tensor,
-                      group_list_type: int,
+                      src_list_type: int,
+                      dst_list_type: int,
                       active_num: int = 0,
                       expert_num: int = 0) -> torch.Tensor:
-    if group_list_type not in [0, 1, 2]:
+    if src_list_type not in [0, 1, 2]:
         raise ValueError(
-            f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
+            f"src_list_type should be in [0, 1, 2], but received {src_list_type}"
         )
 
-    if group_list_type == 0:
+    if src_list_type == dst_list_type:
         return group_list
-    if group_list_type == 1:
+    if src_list_type == 1 and dst_list_type == 0:
         return group_list.cumsum(dim=0)
+    if src_list_type == 0 and dst_list_type == 1:
+        # Invert the cumsum: the first count is the first cumulative entry.
+        group_diff = torch.diff(group_list)
+        new_group = torch.cat([group_list[:1], group_diff], dim=0)
+        return new_group
+    if src_list_type == 2 and dst_list_type == 0:
+        experts = pad(group_list[:, 0], (1, 0))
+        tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
+        cumsum_group_list = torch.full(size=(expert_num, ),
+                                       fill_value=active_num,
+                                       dtype=group_list.dtype,
+                                       device=group_list.device)
 
-    experts = pad(group_list[:, 0], (1, 0))
-    tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
-    cumsum_group_list = torch.full(size=(expert_num, ),
-                                   fill_value=active_num,
-                                   dtype=group_list.dtype,
-                                   device=group_list.device)
+        for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
+            if end > start:
+                cumsum_group_list[start:end] = tokens[i]
 
-    for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
-        if end > start:
-            cumsum_group_list[start:end] = tokens[i]
-
-    return cumsum_group_list
+        return cumsum_group_list
+    raise NotImplementedError(
+        f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
+        "This feature is under development.")
 
 
 def quant_apply_mlp(hidden_states: torch.Tensor,
@@ -106,14 +114,15 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 weight=w1,
                 weight_scale=w1_scale,
                 x_scale=pertoken_scale,
-                group_list=cumsum_group_list(group_list, group_list_type),
+                group_list=cumsum_group_list(group_list, group_list_type,
+                                             0),
             ))
     elif fusion and not dynamic_eplb:
         # gmm1: gate_up_proj & act_fn: swiglu
         hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
             x=hidden_states,
             weight=w1[0],
-            group_list=cumsum_group_list(group_list, group_list_type),
+            group_list=cumsum_group_list(group_list, group_list_type, 0),
             weight_scale=w1_scale[0],
             x_scale=pertoken_scale)
         if quantized_hidden_states is not None:
@@ -140,7 +149,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             bias=None,
             quant_scale=None,
             quant_offset=None,
-            group_index=group_list,
+            group_index=cumsum_group_list(group_list, group_list_type, 1),
             activate_left=True,
             quant_mode=1,
         )
@@ -202,7 +211,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 weight=w1,
                 weight_scale=w1_scale,
                 x_scale=pertoken_scale,
-                group_list=cumsum_group_list(group_list, group_list_type),
+                group_list=cumsum_group_list(group_list, group_list_type,
+                                             0),
                 bias=bias1,
             ))
     elif fusion and not dynamic_eplb:
@@ -211,7 +221,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             x=hidden_states,
             weight=w1[0],
             bias=bias1,
-            group_list=cumsum_group_list(group_list, group_list_type),
+            group_list=cumsum_group_list(group_list, group_list_type, 0),
             weight_scale=w1_scale[0],
             x_scale=pertoken_scale)
         if quantized_hidden_states is not None: