From cec1fab5099dde2cf9247a1993c3e97761d9f7ab Mon Sep 17 00:00:00 2001 From: weichen <132029610+Pr0Wh1teGivee@users.noreply.github.com> Date: Wed, 15 Oct 2025 22:25:46 +0800 Subject: [PATCH] Revert "[MoE] [Refactor] Remove manual memory cleanup (#3365)" (#3483) This reverts commit 4f937f561d573ae97f953169865cfbf70d0c220b. ### What this PR does / why we need it? This reverts commit 4f937f561d573ae97f953169865cfbf70d0c220b. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? e2e & ut - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: Pr0Wh1teGivee --- tests/e2e/singlecard/ops/test_fused_moe.py | 14 +- .../test_fused_moe_prepare_and_finalize.py | 56 +-- tests/ut/ops/test_moe_comm_method.py | 28 +- tests/ut/ops/test_token_dispatcher.py | 126 ++--- vllm_ascend/ops/common_fused_moe.py | 8 +- .../ops/moe/fused_moe_prepare_and_finalize.py | 326 ++++++------ vllm_ascend/ops/moe/moe_comm_method.py | 40 +- vllm_ascend/ops/moe/token_dispatcher.py | 474 +++++++++--------- 8 files changed, 500 insertions(+), 572 deletions(-) diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index fae3ecb..4735a5f 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -137,7 +137,6 @@ def test_token_dispatcher_with_all_gather( sorted_hidden_states = dispatch_output["hidden_states"] group_list = dispatch_output["group_list"] group_list_type = dispatch_output.get("group_list_type", 1) - context_metadata = dispatch_output["context_metadata"] expert_output = apply_mlp(hidden_states=sorted_hidden_states, w1=w1_local, @@ -145,10 +144,8 @@ def test_token_dispatcher_with_all_gather( group_list=group_list, group_list_type=group_list_type) - combined_output = dispatcher.token_combine( - hidden_states=expert_output, - context_metadata=context_metadata, - bias=None) + combined_output = dispatcher.token_combine(hidden_states=expert_output, + bias=None) torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map) @@ -218,7 +215,6 @@ def test_token_dispatcher_with_all_gather_quant( group_list = dispatch_output["group_list"] group_list_type = dispatch_output.get("group_list_type", 1) dynamic_scale = dispatch_output["dynamic_scale"] - context_metadata = dispatch_output["context_metadata"] expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, w1=w1, @@ -229,10 +225,8 @@ def test_token_dispatcher_with_all_gather_quant( group_list_type=group_list_type, dynamic_scale=dynamic_scale, with_quant=True) - combined_output = dispatcher.token_combine( - hidden_states=expert_output, - context_metadata=context_metadata, - bias=None) + combined_output = dispatcher.token_combine(hidden_states=expert_output, + bias=None) assert combined_output.shape == (m, k) gc.collect() torch.npu.empty_cache() diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py index 93b73ec..3a9733b 100644 --- a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py +++ b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py @@ -44,8 +44,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mask, context_metadata = layer.prepare( - hidden_states, router_logits) + h_out, r_out, mask = layer.prepare(hidden_states, router_logits) # Check padding and split self.assertEqual(h_out.shape[0], 4) @@ -53,9 +52,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): self.assertEqual(mask.tolist(), [1, 0, 1]) # Finalize - result = layer.finalize(h_out, - reduce_results=False, - context_metadata=context_metadata) + result = layer.finalize(h_out, reduce_results=False) self.assertEqual(result.shape[0], 3) @patch( @@ -80,11 +77,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(4, 8) router_logits = torch.randn(4, 2) - h_out, r_out, mask, context_metadata = layer.prepare( - hidden_states, - router_logits, - enable_shared_expert_dp=False, - replace_allreduce=False) + h_out, r_out, mask = layer.prepare(hidden_states, + router_logits, + enable_shared_expert_dp=False, + replace_allreduce=False) # With TP=2, should split into 2 parts self.assertEqual(h_out.shape[0], 2) @@ -100,9 +96,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): torch.zeros_like(h_out), torch.zeros_like(h_out) ] - final_result = layer.finalize(h_out, - reduce_results=False, - context_metadata=context_metadata) + final_result = layer.finalize(h_out, reduce_results=False) # Should concat back to original size self.assertEqual(final_result.shape[0], 4) @@ -118,15 +112,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, _, context_metadata = layer.prepare( - hidden_states, router_logits) + h_out, r_out, _ = layer.prepare(hidden_states, router_logits) # Pad to tp_size=1, so no change self.assertEqual(h_out.shape[0], 3) - result = layer.finalize(h_out, - reduce_results=False, - context_metadata=context_metadata) + result = layer.finalize(h_out, reduce_results=False) self.assertEqual(result.shape[0], 3) @patch( @@ -142,11 +133,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(2, 8) router_logits = torch.randn(2, 2) - h_out, r_out, _, context_metadata = layer.prepare( - hidden_states, - router_logits, - enable_shared_expert_dp=False, - replace_allreduce=False) + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + enable_shared_expert_dp=False, + replace_allreduce=False) # Split due to TP=2 self.assertEqual(h_out.shape[0], 1) @@ -162,9 +152,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): torch.zeros_like(h_out), torch.zeros_like(h_out) ] - final_result = layer.finalize(h_out, - reduce_results=False, - context_metadata=context_metadata) + final_result = layer.finalize(h_out, reduce_results=False) # Should concat back self.assertEqual(final_result.shape[0], 2) @@ -207,9 +195,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): mock_gate = MagicMock() mock_gate.return_value = (router_logits.repeat(2, 1), None) - h_out, r_out, _, context_metadata = layer.prepare(hidden_states, - router_logits, - gate=mock_gate) + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + gate=mock_gate) # After all-gather with DP=2, should double the batch size self.assertEqual(h_out.shape[0], 12) @@ -221,9 +209,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): return tensor[:3] mock_dp_group.reduce_scatter = mock_reduce_scatter_func - result = layer.finalize(h_out, - reduce_results=False, - context_metadata=context_metadata) + result = layer.finalize(h_out, reduce_results=False) self.assertEqual(result.shape[0], 3) @@ -277,9 +263,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): mock_gate.return_value = (torch.randn(7, 2), None) # Run prepare - h_out, r_out, _, _ = layer.prepare(hidden_states, - router_logits, - gate=mock_gate) + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + gate=mock_gate) # Should be global tensor: [7, 8] and [7, 2] self.assertEqual(h_out.shape, (7, 8)) diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index a3ef441..3826a19 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -45,7 +45,7 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), None, None) + torch.randn(4, 2), None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -59,18 +59,15 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( - hidden_states, router_logits) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( hidden_states, router_logits, False, False, None) # Test finalize method - comm_impl.finalize(h_out, - reduce_results=True, - context_metadata=context_metadata) - mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) + comm_impl.finalize(h_out, reduce_results=True) + mock_pf_instance.finalize.assert_called_once_with(h_out, True) @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @@ -93,8 +90,7 @@ class TestMoECommMethod(TestBase): mock_pf_instance = MagicMock() mock_pf_instance.prepare.return_value = (torch.randn(4, 8), torch.randn(4, 2), - torch.tensor([1, 0, 1, - 0]), None) + torch.tensor([1, 0, 1, 0])) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -108,18 +104,15 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( - hidden_states, router_logits) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( hidden_states, router_logits, False, False, None) # Test finalize method - comm_impl.finalize(h_out, - reduce_results=True, - context_metadata=context_metadata) - mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) + comm_impl.finalize(h_out, reduce_results=True) + mock_pf_instance.finalize.assert_called_once_with(h_out, True) @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @@ -142,7 +135,7 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), None, None) + torch.randn(4, 2), None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -156,8 +149,7 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( - hidden_states, router_logits) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 486696c..87f384f 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -77,10 +77,9 @@ class TestTokenDispatcherWithMC2(TestBase): topk_ids = torch.randint(0, 8, (10, 1)) topk_weights = torch.randn(10, 1) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - mc2_mask = None kwargs = self.dispatcher.get_dispatch_mc2_kwargs( - hidden_states, topk_weights, topk_ids, expert_map, mc2_mask) + hidden_states, topk_weights, topk_ids, expert_map) self.assertIn("x", kwargs) self.assertIn("expert_ids", kwargs) self.assertEqual(kwargs["moe_expert_num"], 8) @@ -124,64 +123,36 @@ class TestTokenDispatcherWithMC2(TestBase): def test_get_combine_mc_kwargs_with_quant(self): self.dispatcher.with_quant = True hidden_states = torch.randn(10, 128) - topk_ids = torch.randint(0, 8, (10, 1)) - topk_weights = torch.randn(10, 1) # 注意:应为 float,不是 int - expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - mc2_mask = None - assist_info_for_combine = torch.arange(10) # mock 值 - - context_metadata = { - "topk_ids": topk_ids, - "topk_weights": topk_weights, - "expert_map": expert_map, - "ep_recv_counts": ep_recv_counts, - "mc2_mask": mc2_mask, - "assist_info_for_combine": assist_info_for_combine, - "expand_scales": None, - } - + self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) + self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) self.dispatcher.need_extra_args = True self.dispatcher.enable_dispatch_v2 = True + self.dispatcher.output = torch.randint(0, 8, (10, 1)) - kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states, - context_metadata) + kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states) self.assertIn("tp_send_counts", kwargs) def test_token_combine_with_shared_experts(self): - shared_experts = MagicMock() - shared_experts.down_proj.return_value = (torch.randn(10, 128), - torch.tensor(1.0)) - - topk_ids = torch.randint(0, 8, (10, 1)) - topk_weights = torch.randn(10, 1) - expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - assist_info_for_combine = torch.arange(10) - - context_metadata = { - "topk_ids": topk_ids, - "topk_weights": topk_weights, - "expert_map": expert_map, - "ep_recv_counts": ep_recv_counts, - "mc2_mask": None, - "assist_info_for_combine": assist_info_for_combine, - "expand_scales": None, - "shared_experts": shared_experts, - "shared_act": torch.randn(10, 128), - "swiglu_out_scale": torch.randn(10, 1), - } - + self.dispatcher.shared_experts = MagicMock() + self.dispatcher.shared_experts.down_proj.return_value = (torch.randn( + 10, 128), torch.tensor(1.0)) + self.dispatcher.shared_act = torch.randn(10, 128) self.dispatcher.with_quant = True + self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) + self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) self.dispatcher.need_extra_args = True self.dispatcher.enable_dispatch_v2 = True + self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1)) + self.dispatcher.output = torch.randint(0, 8, (10, 1)) + self.hidden_states = torch.randn(10, 128) - hidden_states = torch.randn(10, 128) with patch("torch_npu.npu_moe_distribute_combine_v2", return_value=torch.randn(10, 128)): - result = self.dispatcher.token_combine(hidden_states, - context_metadata) - self.assertIsInstance(result, tuple) + self.dispatcher.token_combine(self.hidden_states) class TestTokenDispatcherWithAllGather(TestBase): @@ -293,26 +264,35 @@ class TestTokenDispatcherWithAllGather(TestBase): self.assertEqual(results["group_list_type"], 1) def test_token_combine_with_expert_map(self): + self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) + self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.sorted_weights = torch.tensor( + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) + self.dispatcher.original_shape = (3, 128) + self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) hidden_states = torch.randn(6, 128) - context_metadata = { - "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), - "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), - } - self.dispatcher.original_shape = (6, 128) - final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata) + + final_hidden_states = self.dispatcher.token_combine(hidden_states) self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_combine_without_expert_map(self): + self.dispatcher.with_quant = False + self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) + self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) + self.dispatcher.sorted_weights = torch.tensor( + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) + self.dispatcher.original_shape = (3, 128) + self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) hidden_states = torch.randn(6, 128) - context_metadata = { - "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), - "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), - } - self.dispatcher.original_shape = (6, 128) - final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata) + + final_hidden_states = self.dispatcher.token_combine(hidden_states) + + # Verify npu_moe_finalize_routing is called self.mock_npu_moe_token_unpermute.assert_called_once() + args, kwargs = self.mock_npu_moe_token_unpermute.call_args + self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_dispatch_with_router_weight(self): @@ -438,21 +418,25 @@ class TestTokenDispatcherWithAll2AllV(TestBase): self.assertEqual(result["group_list_type"], 1) def test_token_combine(self): - hidden_states = torch.randn(16, 16) - context_metadata = { - "input_splits": [4, 4], - "output_splits": [4, 4], - "topk_weights": torch.rand(8, 4), - "reversed_local_input_permutation_mapping": torch.arange(8), - "reversed_global_input_permutation_mapping": torch.arange(16), - } self.dispatcher.hidden_shape = (8, 16) self.dispatcher.hidden_shape_before_permute = (8, 16) + self.dispatcher.reversed_local_input_permutation_mapping = torch.arange( + 8) + self.dispatcher.topk_weights = torch.rand(8, 4) + self.dispatcher.input_splits = [4, 4] + self.dispatcher.output_splits = [4, 4] + self.dispatcher.reversed_global_input_permutation_mapping = torch.arange( + 16) + self.dispatcher.expert_ids_per_ep_rank = torch.tensor( [0, 1], dtype=torch.int32) self.dispatcher.local_expert_indices = [0, 1] + self.dispatcher.num_global_tokens_per_local_expert = torch.tensor( + [[2, 2], [2, 2]], dtype=torch.int64) + + expert_output = torch.randn(16, 16) + output = self.dispatcher.token_combine(expert_output) - output = self.dispatcher.token_combine(hidden_states, context_metadata) self.assertIsNotNone(output) self.assertEqual(output.shape, (8, 16)) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 78183ac..bac07b2 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -301,7 +301,7 @@ class AscendFusedMoE(FusedMoE): enable_force_load_balance = forward_context.in_profile_run forward_context = get_forward_context() - hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( + hidden_states, router_logits = forward_context.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, replace_allreduce=forward_context.sp_enabled, @@ -329,8 +329,7 @@ class AscendFusedMoE(FusedMoE): shared_experts=None, enable_force_load_balance=enable_force_load_balance, log2phy=self.log2phy, - global_redundant_expert_num=self.global_redundant_expert_num, - mc2_mask=mc2_mask) + global_redundant_expert_num=self.global_redundant_expert_num) if isinstance(final_hidden_states, tuple): final_hidden_states, group_list_type, expert_tokens = final_hidden_states @@ -341,8 +340,7 @@ class AscendFusedMoE(FusedMoE): final_hidden_states = forward_context.moe_comm_method.finalize( hidden_states=final_hidden_states, - reduce_results=self.reduce_results, - context_metadata=context_metadata) + reduce_results=self.reduce_results) return final_hidden_states diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py index 7533cce..19e4989 100644 --- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -15,7 +15,6 @@ # This file is a part of the vllm-ascend project. from abc import ABC, abstractmethod -from typing import Optional import torch import torch.distributed as dist @@ -50,15 +49,12 @@ class FusedMoEPrepareAndFinalize(ABC): is_deepseek_v3_r1) @abstractmethod - def prepare( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Prepare tensors before MoE computation. May involve: - Padding to align communication boundaries @@ -78,14 +74,11 @@ class FusedMoEPrepareAndFinalize(ABC): - processed hidden_states (may be padded/sliced/broadcasted) - processed router_logits (may be recomputed or broadcasted) - optional communication mask (e.g., mc2_mask for sparse ops) - - optional context metadata (e.g., saved split_hidden_states for finalization) """ raise NotImplementedError("Prepare not implemented.") - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: """ Finalize MoE output. May involve: - Gathering sliced tensors across TP ranks @@ -103,102 +96,9 @@ class FusedMoEPrepareAndFinalize(ABC): raise NotImplementedError("Finalize function not implemented.") -class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): +class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): """ - MoE communication strategy using All-to-All style slicing. - Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. - Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). - """ - - def __init__(self, moe_config: FusedMoEConfig): - super().__init__(moe_config) - self._restore_tp_across_dp() - - def _restore_tp_across_dp(self): - """Restore original TP configuration (same as MC2).""" - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - def prepare( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: - """ - Preparation steps: - 1. Pad hidden_states and router_logits to next multiple of TP size. - 2. If TP > 1, split along token dim and select current TP rank's slice. - 3. Save splits for later all-gather in finalize. - - Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. - - Returns: - Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All. - """ - self.replace_allreduce = replace_allreduce - self.enable_shared_expert_dp = enable_shared_expert_dp - split_hidden_states = None - - if not (self.replace_allreduce or self.enable_shared_expert_dp): - self.num_tokens, _ = hidden_states.shape - pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) - - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - if self.tp_size > 1: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) - - hidden_states = split_hidden_states[self.tp_rank] - router_logits = split_router_logits[self.tp_rank] - - context_metadata = {"split_hidden_states": split_hidden_states} - - return hidden_states, router_logits, None, context_metadata - - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: - """ - Finalization steps: - 1. If TP > 1, all-gather slices to reconstruct full tensor. - 2. Unpad to original token count. - 3. Return [original_num_tokens, hidden_size] tensor. - - Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. - """ - assert context_metadata is not None - - split_hidden_states = context_metadata["split_hidden_states"] - if not (self.enable_shared_expert_dp or self.replace_allreduce): - if self.tp_size > 1: - dist.all_gather(list(split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) - hidden_states = torch.cat(split_hidden_states, dim=0) - - if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] - - return hidden_states - - -class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All): - """ - MoE communication strategy using MC2, which is based on All2All. Hence, it inherits - All2All and share the same finalize method. + MoE communication strategy using MC2 (Memory-Centric Communication). Designed for Ascend or environments requiring explicit padding and slicing control. Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment. """ @@ -216,15 +116,12 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - def prepare( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: 1. Fetch `mc2_mask` and target padding length from forward context. @@ -235,11 +132,10 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All): Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. Returns: - Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded. + Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded. """ self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp - split_hidden_states = None forward_context = get_forward_context() mc2_mask = forward_context.mc2_mask if self.tp_size > 1: @@ -269,10 +165,124 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All): dim=0) hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank] + self.split_hidden_states = split_hidden_states # Save for finalize - context_metadata = {"split_hidden_states": split_hidden_states} + return hidden_states, router_logits, mc2_mask - return hidden_states, router_logits, mc2_mask, context_metadata + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor. + 2. Unpad to original token count if padding was applied. + 3. Return tensor with shape [original_num_tokens, hidden_size]. + + Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True. + """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + # All-gather across TP group + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + # TODO: It is a quick bugfix for the memory explosion issue in eager mode. + # If the cache is not cleared after `self.split_hidden_states` is created, + # it can lead to the memory explosion in eager mode. + del self.split_hidden_states + + # Unpad if necessary + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using All-to-All style slicing. + Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. + Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). + """ + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + """Restore original TP configuration (same as MC2).""" + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Pad hidden_states and router_logits to next multiple of TP size. + 2. If TP > 1, split along token dim and select current TP rank's slice. + 3. Save splits for later all-gather in finalize. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, None) — no mask used in All2All. + """ + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + + if not (self.replace_allreduce or self.enable_shared_expert_dp): + self.num_tokens, _ = hidden_states.shape + pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) + + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + self.split_hidden_states = split_hidden_states + + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If TP > 1, all-gather slices to reconstruct full tensor. + 2. Unpad to original token count. + 3. Return [original_num_tokens, hidden_size] tensor. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + # TODO: It is a quick bugfix for the memory explosion issue in eager mode. + # If the cache is not cleared after `self.split_hidden_states` is created, + # it can lead to the memory explosion in eager mode. + del self.split_hidden_states + + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): @@ -297,15 +307,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): TP AG → Attn → TP RS → EP AG → MoE → EP RS """ - def prepare( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: AllGather hidden_states and router_logits to form global tensors. @@ -324,24 +331,21 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states, True, True) router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( router_logits, True, True) - return hidden_states, router_logits, None, None + return hidden_states, router_logits, None def _prepare_with_dp_group( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: 1. Fetch max token count across DP group from forward context. @@ -349,7 +353,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): 3. All-gather across DP group to form global input tensor. Returns: - Tuple of (global_hidden_states, global_router_logits, None, None) + Tuple of (global_hidden_states, global_router_logits, None) """ self.enable_shared_expert_dp = enable_shared_expert_dp if self.moe_config.dp_size > 1: @@ -373,12 +377,11 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): else: router_logits = self.moe_config.dp_group.all_gather( router_logits, 0) - return hidden_states, router_logits, None, None - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: """ Finalization steps: Reduce Scatter hidden states. @@ -469,22 +472,19 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): get_dp_group().broadcast(buffer[start:end, :], idx) return buffer - def prepare( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preparation steps: 1. Fetch cumulative token boundaries from forward context. 2. Multicast hidden_states and router_logits to form global tensors. Returns: - Tuple of (global_hidden_states, global_router_logits, None, None) + Tuple of (global_hidden_states, global_router_logits, None) """ self.enable_shared_expert_dp = enable_shared_expert_dp @@ -499,12 +499,10 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): router_logits = self._naive_multicast( router_logits, self.cu_tokens_across_dp_cpu) - return hidden_states, router_logits, None, None + return hidden_states, router_logits, None - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: """ Finalization steps: 1. If DP > 1 and not shared expert: diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py index a836443..d1d0c1a 100644 --- a/vllm_ascend/ops/moe/moe_comm_method.py +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -57,31 +57,28 @@ class MoECommMethod(ABC): self.model_type = get_current_vllm_config( ).model_config.hf_config.model_type self.moe_config = moe_config + self.mc2_mask = None self.token_dispatcher = self._get_token_dispatcher() self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( ) - def prepare( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: - hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare( + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, gate) - return hidden_states, router_logits, mc2_mask, context_metadata + self.mc2_mask = mc2_mask + return hidden_states, router_logits - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: hidden_states = self.fused_moe_prepare_finalize.finalize( - hidden_states, reduce_results, context_metadata) + hidden_states, reduce_results) return hidden_states def fused_experts( @@ -111,8 +108,7 @@ class MoECommMethod(ABC): log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, need_trans: bool = False, - dynamic_eplb: bool = False, - mc2_mask: torch.Tensor = None): + dynamic_eplb: bool = False): # Check constraints assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 @@ -131,12 +127,12 @@ class MoECommMethod(ABC): shared_experts=shared_experts, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, - mc2_mask=mc2_mask, + mc2_mask=self.mc2_mask, apply_router_weight_on_input=apply_router_weight_on_input, with_quant=use_int8_w8a8 or use_int4_w4a8) - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \ - results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata") + permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ + results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, w1=w1, @@ -156,7 +152,7 @@ class MoECommMethod(ABC): dynamic_eplb=dynamic_eplb) final_hidden_states = self.token_dispatcher.token_combine( - hidden_states=mlp_output, context_metadata=context_metadata) + hidden_states=mlp_output) if dynamic_eplb: return (final_hidden_states, group_list_type, expert_tokens) diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index 9e4f220..3dd799a 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -75,7 +75,6 @@ class MoETokenDispatcher(ABC): @abstractmethod def token_combine(self, hidden_states: torch.Tensor, - context_metadata: dict, bias: torch.Tensor = None): raise NotImplementedError("Combine function not implemented.") @@ -103,7 +102,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. self.need_expert_scale = is_hierarchical_communication_enabled() + self.output = None + self.assist_info_for_combine = None + self.ep_recv_counts = None + self.shared_act = None + self.topk_ids = None + self.topk_weights = None + self.shared_experts = None + self.mc2_mask = None self.with_quant = False + self.expand_scales = None def get_dispatch_mc2_kwargs( self, @@ -111,7 +119,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, - mc2_mask: torch.Tensor, global_redundant_expert_num: int = 0, ): if self.with_quant: @@ -148,7 +155,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): }) if self.a3_need_extra_args and self.enable_dispatch_v2: stage1_kwargs.update({ - "x_active_mask": mc2_mask, + "x_active_mask": self.mc2_mask, }) if self.need_expert_scale: stage1_kwargs.update({ @@ -159,121 +166,99 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 - def token_dispatch( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - ): + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False): self.with_quant = with_quant - - # Apply log2phy if needed - if log2phy is not None: - topk_ids = log2phy[topk_ids] + self.expert_map = expert_map + self.topk_ids = topk_ids + self.topk_weights = topk_weights + self.shared_experts = shared_experts + self.mc2_mask = mc2_mask kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, topk_ids, expert_map, - mc2_mask, global_redundant_expert_num) - output = torch_npu.npu_moe_distribute_dispatch_v2( + self.output = torch_npu.npu_moe_distribute_dispatch_v2( **kwargs_mc2 ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( **kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ - ep_recv_counts, _, expand_scales = output[0:7] + expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \ + self.ep_recv_counts, _, self.expand_scales = self.output[0:7] - # Handle shared experts (store intermediate results in local vars, not self) - shared_act = None - swiglu_out_scale = None - if with_quant: + if self.with_quant: if shared_experts is not None: share_up_out, _ = shared_experts.gate_up_proj( (quantized_x_for_share, dynamic_scale_for_share)) shared_gate_up, shared_dequant_scale = share_up_out[ 0], share_up_out[1] + shared_act_out = shared_experts.act_fn( (shared_gate_up, shared_dequant_scale)) - shared_act, swiglu_out_scale = shared_act_out[ - 0], shared_act_out[1] + self.shared_act, self.swiglu_out_scale = \ + shared_act_out[0], shared_act_out[1] + else: if shared_experts is not None: shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - shared_act = shared_experts.act_fn(shared_gate_up) - - context_metadata = { - "topk_ids": topk_ids, - "topk_weights": topk_weights, - "mc2_mask": mc2_mask, - "expert_map": expert_map, - "ep_recv_counts": ep_recv_counts, - "assist_info_for_combine": assist_info_for_combine, - "shared_experts": shared_experts, - "shared_act": shared_act, - "swiglu_out_scale": swiglu_out_scale, - "expand_scales": expand_scales - } - + self.shared_act = shared_experts.act_fn(shared_gate_up) + group_list_type = 0 return { - "group_list_type": 0, + "group_list_type": group_list_type, "hidden_states": expand_x, "group_list": expert_token_nums, "dynamic_scale": dynamic_scale, - "context_metadata": context_metadata } - def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, - context_metadata: dict): - expert_map = context_metadata["expert_map"] - topk_ids = context_metadata["topk_ids"] - topk_weights = context_metadata["topk_weights"] - ep_recv_counts = context_metadata["ep_recv_counts"] - assist_info_for_combine = context_metadata["assist_info_for_combine"] - mc2_mask = context_metadata["mc2_mask"] - expand_scales = context_metadata["expand_scales"] - - assert expert_map is not None - moe_expert_num = len(expert_map) - + def get_combine_mc_kwargs(self, hidden_states: torch.Tensor): + assert self.expert_map is not None + assert self.topk_weights is not None + assert self.topk_ids is not None + assert self.output is not None + moe_expert_num = len(self.expert_map) + # moeCombine kwargs_mc2 = { "expand_x": hidden_states, - "expert_ids": topk_ids, - "expert_scales": topk_weights.to(torch.float32), + "expert_ids": self.topk_ids, + "expert_scales": self.topk_weights.to(torch.float32), "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": 0, } - if self.with_quant: tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device) else: - tp_recv_counts = ep_recv_counts - + tp_recv_counts = self.output[5] stage3_kwargs = { - "ep_send_counts": ep_recv_counts, + "ep_send_counts": self.ep_recv_counts, "group_ep": self.moe_all_to_all_group_name, "ep_world_size": self.ep_world_size, "ep_rank_id": self.ep_rank_id, - "expand_scales": expand_scales, + "expand_scales": self.expand_scales, } - if self.enable_dispatch_v2: - stage3_kwargs["assist_info_for_combine"] = assist_info_for_combine + stage3_kwargs.update({ + "assist_info_for_combine": + self.assist_info_for_combine, + }) else: - stage3_kwargs["expand_idx"] = assist_info_for_combine - + stage3_kwargs.update({ + "expand_idx": self.assist_info_for_combine, + }) if self.need_extra_args: stage3_kwargs.update({ "tp_send_counts": tp_recv_counts, @@ -281,40 +266,45 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "tp_world_size": 1, "tp_rank_id": 0, }) - if self.a3_need_extra_args and self.enable_dispatch_v2: - stage3_kwargs["x_active_mask"] = mc2_mask - + stage3_kwargs.update({ + "x_active_mask": self.mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 - def token_combine( - self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None, - ): - assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states) + hidden_states = torch_npu.npu_moe_distribute_combine_v2( + **kwargs_mc2 + ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( + **kwargs_mc2) - kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, - context_metadata) - combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \ - if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + # these values are no longer used, so they need to be set to None for memory release. + self.output = None + self.assist_info_for_combine = None + self.ep_recv_counts = None + self.topk_ids = None + self.topk_weights = None + self.mc2_mask = None + self.expert_map = None + self.expand_scales = None - # Handle shared experts from metadata - shared_experts = context_metadata["shared_experts"] - if shared_experts is None: - return combined_output - - shared_act = context_metadata["shared_act"] - if self.with_quant: - swiglu_out_scale = context_metadata["swiglu_out_scale"] - shared_hidden_states, _ = shared_experts.down_proj( - (shared_act, swiglu_out_scale)) + if self.shared_experts is None: + return hidden_states else: - shared_hidden_states, _ = shared_experts.down_proj(shared_act) - - return combined_output, shared_hidden_states + if self.with_quant: + shared_hidden_states, _ = self.shared_experts.down_proj( + (self.shared_act, self.swiglu_out_scale)) + else: + shared_hidden_states, _ = self.shared_experts.down_proj( + self.shared_act) + self.shared_act = None + self.shared_experts = None + self.swiglu_out_scale = None + return hidden_states, shared_hidden_states class TokenDispatcherWithAllGather(MoETokenDispatcher): @@ -324,7 +314,14 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.apply_router_weight_on_input = False self.max_num_tokens = kwargs.get("max_num_tokens") self.num_experts_local = kwargs.get("num_local_experts", 0) + self.sorted_weights = None + self.expanded_row_idx = None + self.sorted_token_indices = None self.original_shape = None + self.mask = None + self.expert_map = None + self.topk_weights = None + self.topk_ids = None self.with_quant = False def token_dispatch(self, @@ -344,6 +341,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.original_shape = hidden_states.shape num_tokens = hidden_states.shape[:-1].numel() + self.expert_map = expert_map + self.topk_weights = topk_weights + self.topk_ids = topk_ids self.apply_router_weight_on_input = apply_router_weight_on_input if self.apply_router_weight_on_input: assert (topk_weights.dim() == 2 @@ -357,7 +357,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): if expert_map is not None: global_num_experts = len(expert_map) mask = (expert_map[topk_ids] != -1) - topk_weights = topk_weights * mask + self.topk_weights = topk_weights * mask first_expert_idx = get_ep_group( ).rank_in_group * self.num_experts_local last_expert_idx = first_expert_idx + self.num_experts_local @@ -366,7 +366,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): last_expert_idx = self.num_experts_local global_num_experts = self.num_experts_local - sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = ( + sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = ( torch_npu.npu_moe_init_routing_v2( hidden_states, topk_ids, @@ -379,31 +379,29 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): )) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode - context_metadata = { - "topk_weights": topk_weights, - "expanded_row_idx": expanded_row_idx - } return { "group_list_type": group_list_type, "hidden_states": sorted_hidden_states, "group_list": expert_tokens, "dynamic_scale": pertoken_scale if self.with_quant else None, - "context_metadata": context_metadata } def token_combine(self, hidden_states: torch.Tensor, - context_metadata: dict, bias: torch.Tensor = None): assert self.original_shape is not None final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, - sorted_indices=torch.abs(context_metadata["expanded_row_idx"]), - probs=context_metadata["topk_weights"]) + sorted_indices=torch.abs(self.expanded_row_idx), + probs=self.topk_weights) if len(self.original_shape) == 3: final_hidden_states = final_hidden_states.view(self.original_shape) # these values are no longer used, so they need to be set to None for memory release. + self.expert_map = None + self.topk_weights = None + self.topk_ids = None + self.expanded_row_idx = None return final_hidden_states @@ -452,12 +450,11 @@ class TokenDispatcherWithMoge(MoETokenDispatcher): "group_list_type": group_list_type, "hidden_states": sorted_hidden_states, "group_list": group_list, - "topk_scales": topk_scales + "topk_scales": topk_scales, } def token_combine(self, hidden_states: torch.Tensor, - context_metadata: dict, bias: torch.Tensor = None): unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to( torch.int32) @@ -481,8 +478,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): self.num_local_experts = kwargs.get("num_local_experts", 0) self.hidden_shape = None + self.topk_weights = None + self.input_splits = None + self.output_splits = None self.hidden_shape_before_permute = None + # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert = None + + # cached intermediate tensors. + self.tokens_per_expert = None + self.global_input_tokens_local_experts_indices = None + assert self.num_local_experts > 0, "Expected at least one expert" if self.num_local_experts > 1: self.expert_ids_per_ep_rank = torch.tensor( @@ -504,116 +512,96 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): self.local_expert_indices[i + 1] - 1), "local_expert_indices must be continuous" - def token_dispatch( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - ): + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False): self.with_quant = with_quant self.hidden_shape = hidden_states.shape + self.topk_weights = topk_weights + assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights" + assert topk_ids.dim() == 2, "Expected 2D tensor for routing map" if log2phy is not None: topk_ids = log2phy[topk_ids] - ( - permutated_local_input_tokens, - reversed_local_input_permutation_mapping, - tokens_per_expert, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - global_input_tokens_local_experts_indices, - ) = self._dispatch_preprocess(hidden_states, topk_ids) + permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess( + hidden_states, topk_ids) + self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping dynamic_scale_after_all2all = None if self.with_quant: permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant( permutated_local_input_tokens) + _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( - dynamic_scale, output_splits, input_splits, self.ep_group) + dynamic_scale, + self.output_splits, + self.input_splits, + self.ep_group, + ) permute2_ep_all_to_all_handle.wait() dynamic_scale.untyped_storage().resize_(0) _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( - permutated_local_input_tokens, output_splits, input_splits, - self.ep_group) + permutated_local_input_tokens, + self.output_splits, + self.input_splits, + self.ep_group, + ) permute1_ep_all_to_all_handle.wait() permutated_local_input_tokens.untyped_storage().resize_(0) - # Postprocess - global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess( - global_input_tokens, dynamic_scale_after_all2all, - global_input_tokens_local_experts_indices) - - context_metadata = { - "input_splits": - input_splits, - "output_splits": - output_splits, - "topk_weights": - topk_weights, - "reversed_local_input_permutation_mapping": - reversed_local_input_permutation_mapping, - "reversed_global_input_permutation_mapping": - reversed_global_input_permutation_mapping - } - + global_input_tokens, dynamic_scale = self._dispatch_postprocess( + global_input_tokens, dynamic_scale_after_all2all) return { "hidden_states": global_input_tokens, "group_list": tokens_per_expert, - "group_list_type": 1, - "dynamic_scale": dynamic_scale_final, - "context_metadata": context_metadata, + "dynamic_scale": dynamic_scale, + "group_list_type": 1 } - def token_combine( - self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None, - ): + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." - # 1. Preprocess using metadata - hidden_states = self._combine_preprocess(hidden_states, - context_metadata) + hidden_states = self._combine_preprocess(hidden_states) - # 2. AllToAll + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] _, permutated_local_input_tokens, handle = async_all_to_all( - hidden_states, - context_metadata["input_splits"], - context_metadata["output_splits"], - self.ep_group, - ) + hidden_states, self.input_splits, self.output_splits, + self.ep_group) handle.wait() hidden_states.untyped_storage().resize_(0) - # 3. Postprocess using metadata - output = self._combine_postprocess(permutated_local_input_tokens, - context_metadata) + output = self._combine_postprocess(permutated_local_input_tokens) + + # these values are no longer used, so they need to be set to None for memory release. + self.input_splits = None + self.output_splits = None + self.num_global_tokens_per_local_expert = None + self.topk_weights = None + self.reversed_local_input_permutation_mapping = None + self.reversed_global_input_permutation_mapping = None + self.global_input_tokens_local_experts_indices = None return output def _dispatch_preprocess(self, hidden_states, topk_ids): assert self.hidden_shape is not None - hidden_states = hidden_states.view(-1, hidden_states.size(-1)) - ( - tokens_per_expert, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - global_input_tokens_local_experts_indices, - ) = self._preprocess(topk_ids) + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self._preprocess(topk_ids) self.hidden_shape_before_permute = hidden_states.shape @@ -622,88 +610,82 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): indices=topk_ids, num_out_tokens=self.num_out_tokens, ) + return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert - return ( - permutated_local_input_tokens, - reversed_local_input_permutation_mapping, - tokens_per_expert, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - global_input_tokens_local_experts_indices, - ) - - def _preprocess(self, topk_ids: torch.Tensor): + def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts) ep_size = self.ep_size + + # Dropless self.num_out_tokens = topk_ids.numel() - input_splits = (num_local_tokens_per_expert.reshape( + # =================================================== + # Calculate input_splits, output_splits for alltoall-v. + # =================================================== + self.input_splits = (num_local_tokens_per_expert.reshape( ep_size, self.num_local_experts).sum(axis=1).to(torch.device("cpu"), non_blocking=True).numpy()) - num_global_tokens_per_expert = gather_from_sequence_parallel_region( num_local_tokens_per_expert, group=self.ep_group).reshape(ep_size, self.num_experts) - num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ + self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ 0]:self.local_expert_indices[-1] + 1] - if num_global_tokens_per_local_expert is None: + if self.num_global_tokens_per_local_expert is None: raise ValueError( "num_global_tokens_per_local_expert must be set before sum.") - - output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to( - torch.device("cpu"), non_blocking=True).numpy()) - num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum( + self.output_splits = (self.num_global_tokens_per_local_expert.sum( + axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) + num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( axis=0) + # =================================================== + # num_global_tokens_per_expert: [ep_size, num_experts] + # num_global_tokens_per_local_expert: [ep_size, num_local_experts] + # num_tokens_per_local_expert: [num_local_experts] + # =================================================== - global_input_tokens_local_experts_indices = None if self.num_local_experts > 1: - if num_global_tokens_per_local_expert is None: + if self.num_global_tokens_per_local_expert is None: raise ValueError( "num_global_tokens_per_local_expert must be set before operations." ) - global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, - num_global_tokens_per_local_expert.ravel()) + self.num_global_tokens_per_local_expert.ravel()) else: + # TODO: This full synchronization can be a performance bottleneck. + # A more granular sync (e.g., blocking D2H copies) should be investigated. torch.npu.synchronize() - return ( - num_tokens_per_local_expert, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - global_input_tokens_local_experts_indices, - ) + return num_tokens_per_local_expert - def _dispatch_postprocess(self, global_input_tokens, - dynamic_scale_after_all2all, - global_input_tokens_local_experts_indices): + def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None): # Early return if no local experts or no tokens if self.num_local_experts <= 1: - return global_input_tokens, dynamic_scale_after_all2all, None + return global_input_tokens, None # Handle quantized case if self.with_quant: - assert global_input_tokens_local_experts_indices is not None, \ - "global_input_tokens_local_experts_indices must be provided" - expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze( + assert self.global_input_tokens_local_experts_indices is not None, \ + "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess" + expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze( -1) - active_num = global_input_tokens_local_experts_indices.numel() + active_num = self.global_input_tokens_local_experts_indices.numel() + # Handle case with no active tokens if active_num <= 0: - reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices - return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping + self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices + return global_input_tokens, dynamic_scale - global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2( + # Process with active tokens + global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2( global_input_tokens, expert_idx_2d, - scale=dynamic_scale_after_all2all, + scale=dynamic_scale, active_num=active_num, expert_capacity=0, expert_num=self.num_local_experts, @@ -711,34 +693,32 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): expert_tokens_num_flag=True, active_expert_range=[0, self.num_local_experts], quant_mode=-1, - row_idx_type=0, - ) - return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping + row_idx_type=0) + return global_input_tokens, expanded_scale - # Non-quantized case - global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( - global_input_tokens, global_input_tokens_local_experts_indices) - return global_input_tokens, None, reversed_global_input_permutation_mapping + # Handle non-quantized case + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + global_input_tokens, + self.global_input_tokens_local_experts_indices) + return global_input_tokens, None - def _combine_preprocess(self, hidden_states: torch.Tensor, - context_metadata: dict) -> torch.Tensor: + def _combine_preprocess(self, hidden_states): # Unpermutation 2: expert output to AlltoAll input if hidden_states.shape[0] > 0 and self.num_local_experts > 1: - rev_global = context_metadata[ - "reversed_global_input_permutation_mapping"] hidden_states = torch_npu.npu_moe_token_unpermute( - hidden_states, rev_global) + hidden_states, self.reversed_global_input_permutation_mapping) + return hidden_states - def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, - context_metadata: dict) -> torch.Tensor: + def _combine_postprocess(self, permutated_local_input_tokens): # Unpermutation 1: AlltoAll output to output output = torch_npu.npu_moe_token_unpermute( permuted_tokens=permutated_local_input_tokens, - sorted_indices=context_metadata[ - "reversed_local_input_permutation_mapping"].to(torch.int32), - probs=context_metadata["topk_weights"], - restore_shape=self.hidden_shape_before_permute, - ) + sorted_indices=self.reversed_local_input_permutation_mapping.to( + torch.int32), + probs=self.topk_weights, + restore_shape=self.hidden_shape_before_permute) + + # Reshape the output tensor output = output.view(self.hidden_shape) return output