diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 6c89f6f..46192f6 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -112,7 +112,7 @@ def mock_moe_env(mocker: MockerFixture): torch.randn(16, 2) )), \ patch("torch_npu.npu_grouped_matmul", return_value=( - (torch.randn(8, 2), torch.randn(8, 2)) + [torch.randn(16, 2)] )), \ patch("torch_npu.npu_swiglu", return_value=( torch.randn(16, 2) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index f35fb10..f3c14bf 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -205,11 +205,9 @@ def fused_experts_with_mc2( group_list_type=1, group_type=0, group_list=group_list, - ) + )[0] - # TODO: Remove this in the future. - gate_up_out = torch.cat(gate_up_out_list, dim=0) - gate_up_out = torch_npu.npu_swiglu(gate_up_out) + gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) w2 = w2.transpose(1, 2) down_out_list = torch_npu.npu_grouped_matmul( @@ -219,9 +217,7 @@ def fused_experts_with_mc2( group_list_type=1, group_type=0, group_list=group_list, - ) - - down_out_list = torch.cat(down_out_list, dim=0) + )[0] # moeCombine kwargs_mc2 = { @@ -312,9 +308,8 @@ def apply_mlp( group_list_type=group_list_type, group_type=0, group_list=group_list, - ) + )[0] - hidden_states = torch.cat(hidden_states, dim=0) hidden_states = torch_npu.npu_swiglu(hidden_states) w2 = w2.transpose(1, 2) @@ -325,9 +320,8 @@ def apply_mlp( group_list_type=group_list_type, group_type=0, group_list=group_list, - ) + )[0] - hidden_states = torch.cat(hidden_states, dim=0) return hidden_states @@ -417,23 +411,19 @@ def fused_experts_with_all2all( group_list_type=0, group_type=0, group_list=expert_tokens, - ) + )[0] - # TODO: Remove this in the future. - hidden_states = torch.cat(gate_up_out_list, dim=0) - hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states = torch_npu.npu_swiglu(gate_up_out_list) w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( + hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w2], split_item=2, group_list_type=0, group_type=0, group_list=expert_tokens, - ) - - hidden_states = torch.cat(down_out_list, dim=0) + )[0] if expert_map is not None: resorted_idx = torch.argsort(sorted_idx) @@ -823,11 +813,9 @@ def fused_experts( group_list_type=0, group_type=0, group_list=expert_tokens, - ) + )[0] - # TODO: Remove this in the future. - gate_up_out = torch.cat(gate_up_out_list, dim=0) - gate_up_out = torch_npu.npu_swiglu(gate_up_out) + gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) w2 = w2.transpose(1, 2) down_out_list = torch_npu.npu_grouped_matmul( @@ -837,9 +825,7 @@ def fused_experts( group_list_type=0, group_type=0, group_list=expert_tokens, - ) - - down_out_list = torch.cat(down_out_list, dim=0) + )[0] if expert_map is not None: weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)