diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 2a3383c..724c1be 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -15,17 +15,17 @@ # limitations under the License. # This file is a part of the vllm-ascend project. -import unittest -from unittest import mock +from unittest.mock import MagicMock, PropertyMock, patch import pytest import torch from pytest_mock import MockerFixture -from tests.ut.base import PytestBase +from tests.ut.base import PytestBase, TestBase from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig, - TokenDispatcherWithAllGather, TokenDispatcherWithMC2) + TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, + TokenDispatcherWithMC2) from vllm_ascend.utils import adapt_patch # noqa E402 @@ -70,40 +70,40 @@ class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase): assert dispatcher.overlap_stream is not None -class TestTokenDispatcherWithMC2(unittest.TestCase): +class TestTokenDispatcherWithMC2(TestBase): def setUp(self): - self.mc2_group = mock.MagicMock() + self.mc2_group = MagicMock() self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123" self.mc2_group.rank_in_group = 0 self.mc2_group.world_size = 8 - self.mc2_group_patch = mock.patch( + self.mc2_group_patch = patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group", return_value=self.mc2_group) self.mc2_group_patch.start() - self.rank_group_patch = mock.patch("torch.distributed.get_rank", - return_value=0) + self.rank_group_patch = patch("torch.distributed.get_rank", + return_value=0) self.rank_group_patch.start() # Mock get_forward_context().mc2_mask - self.forward_context = mock.MagicMock() + self.forward_context = MagicMock() self.forward_context.mc2_mask = torch.tensor([1, 0, 1]) - self.forward_context_patch = mock.patch( + self.forward_context_patch = patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context", return_value=self.forward_context) self.forward_context_patch.start() # Mock get_ascend_soc_version() - self.ascend_soc_version_patch = mock.patch( + self.ascend_soc_version_patch = patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version", return_value=AscendSocVersion.A3) self.ascend_soc_version_patch.start() # Mock get_ascend_config() - self.ascend_config = mock.MagicMock() + self.ascend_config = MagicMock() self.ascend_config.torchair_graph_config.enabled = False - self.ascend_config_patch = mock.patch( + self.ascend_config_patch = patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config", return_value=self.ascend_config) self.ascend_config_patch.start() @@ -127,13 +127,13 @@ class TestTokenDispatcherWithMC2(unittest.TestCase): self.assertTrue(self.dispatcher.need_extra_args) self.assertTrue(self.dispatcher.a3_need_extra_args) - def test_get_permute_mc2_kwargs_without_quant(self): + def test_get_dispatch_mc2_kwargs_without_quant(self): hidden_states = torch.randn(10, 128) topk_ids = torch.randint(0, 8, (10, 1)) topk_weights = torch.randn(10, 1) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - kwargs = self.dispatcher.get_permute_mc2_kwargs( + kwargs = self.dispatcher.get_dispatch_mc2_kwargs( hidden_states, topk_weights, topk_ids, expert_map) self.assertIn("x", kwargs) self.assertIn("expert_ids", kwargs) @@ -145,17 +145,16 @@ class TestTokenDispatcherWithMC2(unittest.TestCase): topk_ids = torch.randint(0, 8, (10, 1)) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2", - return_value=(torch.randn(10, 128), ) * - 5) as mock_dispatch: - output = self.dispatcher.token_permutation(hidden_states, - topk_weights, topk_ids, - expert_map) + with patch("torch_npu.npu_moe_distribute_dispatch_v2", + return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch: + output = self.dispatcher.token_dispatch(hidden_states, + topk_weights, topk_ids, + expert_map) mock_dispatch.assert_called_once() self.assertEqual(output[0], 1) # group_list_type == 1 - def test_token_permutation_with_shared_experts_and_quant(self): - self.shared_experts = mock.MagicMock() + def test_token_dispatch_with_shared_experts_and_quant(self): + self.shared_experts = MagicMock() self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128), torch.tensor(1.0)) self.shared_experts.act_fn.return_value = torch.randn(10, 128) @@ -165,15 +164,15 @@ class TestTokenDispatcherWithMC2(unittest.TestCase): self.hidden_states = torch.randn(10, 128) self.topk_weights = torch.randn(10, 1) - with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2", - return_value=(torch.randn(10, 128), ) * 5): - with mock.patch( + with patch("torch_npu.npu_moe_distribute_dispatch_v2", + return_value=(torch.randn(10, 128), ) * 5): + with patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", autospec=True): - with mock.patch( + with patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", autospec=True) as mock_wait: - self.dispatcher.token_permutation( + self.dispatcher.token_dispatch( self.hidden_states, self.topk_weights, torch.randint(0, 8, (10, 1)), @@ -182,7 +181,7 @@ class TestTokenDispatcherWithMC2(unittest.TestCase): mock_wait.assert_any_call(self.hidden_states, self.topk_weights) - def test_get_unpermute_mc_kwargs_with_quant(self): + def test_get_combine_mc_kwargs_with_quant(self): self.dispatcher.with_quant = True hidden_states = torch.randn(10, 128) self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) @@ -193,11 +192,11 @@ class TestTokenDispatcherWithMC2(unittest.TestCase): self.dispatcher.enable_dispatch_v2 = True self.dispatcher.output = torch.randint(0, 8, (10, 1)) - kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states) + kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states) self.assertIn("tp_send_counts", kwargs) - def test_token_unpermutation_with_shared_experts(self): - self.dispatcher.shared_experts = mock.MagicMock() + def test_token_combine_with_shared_experts(self): + self.dispatcher.shared_experts = MagicMock() self.dispatcher.shared_experts.down_proj.return_value = (torch.randn( 10, 128), torch.tensor(1.0)) self.dispatcher.shared_act = torch.randn(10, 128) @@ -212,18 +211,18 @@ class TestTokenDispatcherWithMC2(unittest.TestCase): self.dispatcher.output = torch.randint(0, 8, (10, 1)) self.hidden_states = torch.randn(10, 128) - with mock.patch("torch_npu.npu_moe_distribute_combine_v2", - return_value=torch.randn(10, 128)): - with mock.patch( + with patch("torch_npu.npu_moe_distribute_combine_v2", + return_value=torch.randn(10, 128)): + with patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", autospec=True): - with mock.patch( + with patch( "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", autospec=True): - self.dispatcher.token_unpermutation(self.hidden_states) + self.dispatcher.token_combine(self.hidden_states) -class TestTokenDispatcherWithAllGather(unittest.TestCase): +class TestTokenDispatcherWithAllGather(TestBase): def setUp(self): # Mock dependencies @@ -238,8 +237,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): self.dispatcher = TokenDispatcherWithAllGather(**kwargs) # Mock NPU functions - self.patcher_moe_init_routing = mock.patch( - 'torch_npu.npu_moe_init_routing') + self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing') self.mock_moe_init_routing = self.patcher_moe_init_routing.start() self.mock_moe_init_routing.return_value = ( torch.randn(6, 128), # sorted_hidden_states @@ -247,14 +245,14 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx ) - self.patcher_moe_compute_expert_tokens = mock.patch( + self.patcher_moe_compute_expert_tokens = patch( 'torch_npu.npu_moe_compute_expert_tokens') self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start( ) self.mock_moe_compute_expert_tokens.return_value = torch.tensor( [3, 3]) # expert_tokens - self.patcher_moe_finalize_routing = mock.patch( + self.patcher_moe_finalize_routing = patch( 'torch_npu.npu_moe_finalize_routing') self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start( ) @@ -265,12 +263,12 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): self.patcher_moe_compute_expert_tokens.stop() self.patcher_moe_finalize_routing.stop() - def test_token_permutation_without_expert_map(self): + def test_token_dispatch_without_expert_map(self): hidden_states = torch.randn(3, 128) topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation( + group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch( hidden_states, topk_weights, topk_ids, None) # Verify npu_moe_init_routing is called @@ -279,7 +277,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): self.assertEqual(group_list_type, 0) - def test_token_permutation_with_quant(self): + def test_token_dispatch_with_quant(self): kwargs = { "apply_router_weight_on_input": False, "top_k": 2, @@ -294,13 +292,13 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation( + group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch( hidden_states, topk_weights, topk_ids, None) # Verify quant mode returns group_list_type=1 self.assertEqual(group_list_type, 0) - def test_token_unpermutation_with_expert_map(self): + def test_token_combine_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) self.dispatcher.sorted_weights = torch.tensor( @@ -309,13 +307,12 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) hidden_states = torch.randn(6, 128) - final_hidden_states = self.dispatcher.token_unpermutation( - hidden_states) + final_hidden_states = self.dispatcher.token_combine(hidden_states) # Verify index_add_ is applied correctly self.assertEqual(final_hidden_states.shape, (3, 128)) - def test_token_unpermutation_without_expert_map(self): + def test_token_combine_without_expert_map(self): self.dispatcher.with_quant = False self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) @@ -326,8 +323,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) hidden_states = torch.randn(6, 128) - final_hidden_states = self.dispatcher.token_unpermutation( - hidden_states) + final_hidden_states = self.dispatcher.token_combine(hidden_states) # Verify npu_moe_finalize_routing is called self.mock_moe_finalize_routing.assert_called_once() @@ -335,22 +331,231 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase): self.assertEqual(final_hidden_states.shape, (3, 128)) - def test_token_permutation_with_router_weight(self): + def test_token_dispatch_with_router_weight(self): self.dispatcher.apply_router_weight_on_input = True hidden_states = torch.randn(3, 128) topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1 topk_ids = torch.tensor([[0], [1], [2]]) - group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation( + group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch( hidden_states, topk_weights, topk_ids, None) self.assertEqual(sorted_hidden_states.shape, (6, 128)) - def test_token_permutation_invalid_topk_when_router_weight(self): + def test_token_dispatch_invalid_topk_when_router_weight(self): self.dispatcher.apply_router_weight_on_input = True hidden_states = torch.randn(3, 128) topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) with self.assertRaises(AssertionError): - self.dispatcher.token_permutation( + self.dispatcher.token_dispatch( hidden_states, topk_weights, torch.tensor([[0, 1], [1, 2], [2, 3]]), None) + + +class TestTokenDispatcherWithAll2AllV(TestBase): + + def setUp(self): + # Patch properties + patcher1 = patch.object(TokenDispatcherWithAll2AllV, + 'ep_group', + new_callable=PropertyMock, + return_value=MagicMock()) + patcher2 = patch.object(TokenDispatcherWithAll2AllV, + 'ep_rank', + new_callable=PropertyMock, + return_value=0) + patcher3 = patch.object(TokenDispatcherWithAll2AllV, + 'ep_size', + new_callable=PropertyMock, + return_value=2) + + self.addCleanup(patcher1.stop) + self.addCleanup(patcher2.stop) + self.addCleanup(patcher3.stop) + + self.mock_ep_group_prop = patcher1.start() + self.mock_ep_rank_prop = patcher2.start() + self.mock_ep_size_prop = patcher3.start() + + # Mock torch_npu.npu_moe_token_permute + patcher4 = patch('torch_npu.npu_moe_token_permute') + self.mock_npu_moe_token_permute = patcher4.start() + self.addCleanup(patcher4.stop) + self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16), + torch.arange(16)) + + # Mock torch_npu.npu_moe_token_unpermute + patcher5 = patch('torch_npu.npu_moe_token_unpermute') + self.mock_npu_moe_token_unpermute = patcher5.start() + self.addCleanup(patcher5.stop) + self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16) + + # Mock async_all_to_all + patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all') + self.mock_async_all_to_all = patcher6.start() + self.addCleanup(patcher6.stop) + self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16), + MagicMock()) + + # Mock gather_from_sequence_parallel_region + patcher7 = patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region' + ) + self.mock_gather_from_sequence_parallel_region = patcher7.start() + self.addCleanup(patcher7.stop) + self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor( + [[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64) + + # Mock torch.histc + patcher8 = patch('torch.histc') + self.mock_histc = patcher8.start() + self.addCleanup(patcher8.stop) + self.mock_histc.return_value = torch.tensor([2, 2, 2, 2], + dtype=torch.int64) + + # Mock torch.npu.current_device + patcher9 = patch('torch.npu.current_device') + self.mock_current_device = patcher9.start() + self.addCleanup(patcher9.stop) + self.mock_current_device.return_value = 'cpu' + + # Mock torch_npu.npu_dynamic_quant + patcher10 = patch('torch_npu.npu_dynamic_quant') + self.mock_npu_dynamic_quant = patcher10.start() + self.addCleanup(patcher10.stop) + self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16), + torch.randn(16)) + + # Mock torch_npu.npu_moe_init_routing_v2 + patcher11 = patch('torch_npu.npu_moe_init_routing_v2') + self.mock_npu_moe_init_routing_v2 = patcher11.start() + self.addCleanup(patcher11.stop) + self.mock_npu_moe_init_routing_v2.return_value = (torch.randn( + 16, 16), torch.arange(16), None, torch.randn(16)) + + # Mock torch.repeat_interleave + patcher12 = patch('torch.repeat_interleave') + self.mock_repeat_interleave = patcher12.start() + self.addCleanup(patcher12.stop) + self.mock_repeat_interleave.return_value = torch.arange(16) + + self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2, + num_experts=4, + num_local_experts=2, + with_quant=False) + + def test_token_dispatch(self): + hidden_states = torch.randn(8, 16) + topk_weights = torch.rand(8, 4) + topk_ids = torch.randint(0, 4, (8, 2)).long() + expert_map = torch.tensor([0, 1, 2, 3]) + + self.dispatcher.expert_ids_per_ep_rank = torch.tensor( + [0, 1], dtype=torch.int32) + self.dispatcher.local_expert_indices = [0, 1] + + result = self.dispatcher.token_dispatch(hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map) + + self.assertIsNotNone(result["hidden_states"]) + self.assertIsNotNone(result["group_list"]) + self.assertEqual(result["group_list_type"], 1) + + def test_token_combine(self): + self.dispatcher.hidden_shape = (8, 16) + self.dispatcher.hidden_shape_before_permute = (8, 16) + self.dispatcher.reversed_local_input_permutation_mapping = torch.arange( + 8) + self.dispatcher.topk_weights = torch.rand(8, 4) + self.dispatcher.input_splits = [4, 4] + self.dispatcher.output_splits = [4, 4] + self.dispatcher.reversed_global_input_permutation_mapping = torch.arange( + 16) + + self.dispatcher.expert_ids_per_ep_rank = torch.tensor( + [0, 1], dtype=torch.int32) + self.dispatcher.local_expert_indices = [0, 1] + self.dispatcher.num_global_tokens_per_local_expert = torch.tensor( + [[2, 2], [2, 2]], dtype=torch.int64) + + expert_output = torch.randn(16, 16) + output = self.dispatcher.token_combine(expert_output) + + self.assertIsNotNone(output) + self.assertEqual(output.shape, (8, 16)) + + def test_token_dispatch_with_quant(self): + self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2, + num_experts=4, + num_local_experts=2, + with_quant=True) + + hidden_states = torch.randn(8, 16) + topk_weights = torch.rand(8, 4) + topk_ids = torch.randint(0, 4, (8, 2)).long() + expert_map = torch.tensor([0, 1, 2, 3]) + + self.dispatcher.expert_ids_per_ep_rank = torch.tensor( + [0, 1], dtype=torch.int32) + self.dispatcher.local_expert_indices = [0, 1] + + result = self.dispatcher.token_dispatch(hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map) + + self.assertIsNotNone(result["hidden_states"]) + self.assertIsNotNone(result["group_list"]) + self.assertIsNotNone(result["dynamic_scale"]) + self.assertEqual(result["group_list_type"], 1) + + def test_token_dispatch_with_quant_no_active_tokens(self): + self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2, + num_experts=4, + num_local_experts=2, + with_quant=True) + + self.mock_repeat_interleave.return_value = torch.tensor( + [], dtype=torch.long) + + hidden_states = torch.randn(8, 16) + topk_weights = torch.rand(8, 4) + topk_ids = torch.randint(0, 4, (8, 2)).long() + expert_map = torch.tensor([0, 1, 2, 3]) + + self.dispatcher.expert_ids_per_ep_rank = torch.tensor( + [0, 1], dtype=torch.int32) + self.dispatcher.local_expert_indices = [0, 1] + + result = self.dispatcher.token_dispatch(hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map) + + self.assertIsNotNone(result["hidden_states"]) + self.assertIsNotNone(result["group_list"]) + self.assertIsNotNone(result["dynamic_scale"]) + self.assertEqual(result["group_list_type"], 1) + + def test_token_dispatch_with_log2phy(self): + hidden_states = torch.randn(8, 16) + topk_weights = torch.rand(8, 4) + topk_ids = torch.randint(0, 4, (8, 2)).long() + expert_map = torch.tensor([0, 1, 2, 3]) + log2phy = torch.tensor([1, 0, 3, 2]) + + self.dispatcher.expert_ids_per_ep_rank = torch.tensor( + [0, 1], dtype=torch.int32) + self.dispatcher.local_expert_indices = [0, 1] + + result = self.dispatcher.token_dispatch(hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + log2phy=log2phy) + + self.assertIsNotNone(result["hidden_states"]) + self.assertIsNotNone(result["group_list"]) + self.assertEqual(result["group_list_type"], 1) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index b3cc46c..e02652d 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -22,7 +22,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Optional import torch import torch_npu @@ -466,8 +466,9 @@ class MoETokenDispatcher(ABC): """ Initialize the MoE Token Dispatcher. """ - self.top_k = kwargs.get("top_k") - self.num_experts = kwargs.get("num_experts") + self.top_k = kwargs.get("top_k", 0) + self.num_experts = kwargs.get("num_experts", 0) + self.with_quant = kwargs.get("with_quant", False) @property def ep_group(self): @@ -483,25 +484,25 @@ class MoETokenDispatcher(ABC): return get_ep_group().world_size @abstractmethod - def token_permutation( + def token_dispatch( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, - log2phy: torch.Tensor = None, + log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - shared_experts: Optional[Any] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[torch.Tensor] = None, ): raise NotImplementedError("Dispatch function not implemented.") @abstractmethod - def token_unpermutation(self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None): - raise NotImplementedError("Restore function not implemented.") + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + raise NotImplementedError("Combine function not implemented.") class TokenDispatcherWithMC2(MoETokenDispatcher): @@ -517,7 +518,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.ep_world_size = get_mc2_group().world_size ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.with_quant = kwargs.get("with_quant") self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 @@ -535,12 +535,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.topk_weights = None self.shared_experts = None - def get_permute_mc2_kwargs(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - global_redundant_expert_num: int = 0): + def get_dispatch_mc2_kwargs(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor, + global_redundant_expert_num: int = 0): quant_mode = 0 forward_context = get_forward_context() mc2_mask = forward_context.mc2_mask @@ -581,26 +581,26 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 - def token_permutation( + def token_dispatch( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, - log2phy: torch.Tensor = None, + log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - shared_experts: Optional[Any] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[torch.Tensor] = None, ): self.expert_map = expert_map self.topk_ids = topk_ids self.topk_weights = topk_weights self.shared_experts = shared_experts - kwargs_mc2 = self.get_permute_mc2_kwargs(hidden_states, topk_weights, - topk_ids, expert_map, - global_redundant_expert_num) + kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, + topk_ids, expert_map, + global_redundant_expert_num) self.output = torch_npu.npu_moe_distribute_dispatch_v2( **kwargs_mc2 ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( @@ -629,7 +629,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): group_list_type = 1 return group_list_type, expand_x, expert_token_nums - def get_unpermute_mc_kwargs(self, hidden_states: torch.Tensor): + def get_combine_mc_kwargs(self, hidden_states: torch.Tensor): assert self.expert_map is not None assert self.topk_weights is not None assert self.topk_ids is not None @@ -682,11 +682,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 - def token_unpermutation(self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None): + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): - kwargs_mc2 = self.get_unpermute_mc_kwargs(hidden_states) + kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states) hidden_states = torch_npu.npu_moe_distribute_combine_v2( **kwargs_mc2 ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( @@ -718,7 +718,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): ep_size = kwargs.get("ep_size") if ep_size is not None: self.num_experts_local = self.num_experts // ep_size - self.with_quant = kwargs.get("with_quant") self.sorted_weights = None self.expanded_row_idx = None self.sorted_token_indices = None @@ -728,17 +727,17 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.topk_weights = None self.topk_ids = None - def token_permutation( + def token_dispatch( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, - log2phy: torch.Tensor = None, + log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - shared_experts: Optional[Any] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[torch.Tensor] = None, ): self.original_shape = hidden_states.shape # assert len(original_shape) == 2 @@ -830,9 +829,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): group_list_type = 0 return group_list_type, sorted_hidden_states, expert_tokens - def token_unpermutation(self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None): + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): assert self.mask is not None assert self.sorted_token_indices is not None assert self.sorted_weights is not None @@ -901,23 +900,22 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): "apply_router_weight_on_input") ep_size = kwargs.get("ep_size") self.local_ep = ep_size - self.top_k = kwargs.get("top_k") assert self.local_ep is not None self.local_num_experts = self.num_experts // self.local_ep self.local_num_group = self.top_k // self.local_ep self.bsz = None - def token_permutation( + def token_dispatch( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, - log2phy: torch.Tensor = None, + log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - shared_experts: Optional[Any] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[torch.Tensor] = None, ): if self.apply_router_weight_on_input: @@ -949,9 +947,9 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) return hidden_states, group_list - def token_unpermutation(self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None): + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): assert self.local_ep is not None unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to( torch.int32) @@ -960,3 +958,253 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): final_hidden_states = unsorted_hidden_states.reshape( self.bsz, self.top_k // self.local_ep, -1).sum(1) return final_hidden_states + + +class TokenDispatcherWithAll2AllV(MoETokenDispatcher): + """ + The implementation of the AlltoAll-based token dispatcher, which handles token + dispatching on the sequence level instead of token level. The core of this implementation + lies in each device dispatching on the entire sequence, with the hidden state being partitioned. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.num_local_experts = kwargs.get("num_local_experts", 0) + self.num_global_redundant_experts = kwargs.get( + "num_global_redundant_experts", 0) + self.num_experts = self.num_experts + self.num_global_redundant_experts + + self.hidden_shape = None + self.topk_weights = None + self.input_splits = None + self.output_splits = None + self.hidden_shape_before_permute = None + + # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert = None + + # cached intermediate tensors. + self.tokens_per_expert = None + self.global_input_tokens_local_experts_indices = None + + assert self.num_local_experts > 0, "Expected at least one expert" + if self.num_local_experts > 1: + self.expert_ids_per_ep_rank = torch.tensor( + [i % self.num_local_experts for i in range(self.num_experts)], + dtype=torch.int32, + device=torch.npu.current_device(), + ) + + local_expert_indices_offset = (self.ep_rank * self.num_local_experts) + + self.local_expert_indices = [ + local_expert_indices_offset + i + for i in range(self.num_local_experts) + ] + assert (len(self.local_expert_indices) == self.num_local_experts + ), "Invalid local expert indices" + for i in range(len(self.local_expert_indices) - 1): + assert (self.local_expert_indices[i] == + self.local_expert_indices[i + 1] - + 1), "local_expert_indices must be continuous" + + def token_dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[torch.Tensor] = None, + ): + self.hidden_shape = hidden_states.shape + self.topk_weights = topk_weights + assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights" + assert topk_ids.dim() == 2, "Expected 2D tensor for routing map" + + if log2phy is not None: + topk_ids = log2phy[topk_ids] + + permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess( + hidden_states, topk_ids) + self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping + + dynamic_scale_after_all2all = None + if self.with_quant: + permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant( + permutated_local_input_tokens) + + _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( + dynamic_scale, + self.output_splits, + self.input_splits, + self.ep_group, + ) + permute2_ep_all_to_all_handle.wait() + dynamic_scale.untyped_storage().resize_(0) + + _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( + permutated_local_input_tokens, + self.output_splits, + self.input_splits, + self.ep_group, + ) + permute1_ep_all_to_all_handle.wait() + permutated_local_input_tokens.untyped_storage().resize_(0) + + global_input_tokens, dynamic_scale = self._dispatch_postprocess( + global_input_tokens, dynamic_scale_after_all2all) + return { + "hidden_states": global_input_tokens, + "group_list": tokens_per_expert, + "dynamic_scale": dynamic_scale, + "group_list_type": 1 + } + + def token_combine(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher" + + hidden_states = self._combine_preprocess(hidden_states) + + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + _, permutated_local_input_tokens, handle = async_all_to_all( + hidden_states, self.input_splits, self.output_splits, + self.ep_group) + handle.wait() + hidden_states.untyped_storage().resize_(0) + + output = self._combine_postprocess(permutated_local_input_tokens) + + self.input_splits = None + self.output_splits = None + self.num_global_tokens_per_local_expert = None + + return output + + def _dispatch_preprocess(self, hidden_states, topk_ids): + assert self.hidden_shape is not None + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self._preprocess(topk_ids) + + self.hidden_shape_before_permute = hidden_states.shape + + permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=topk_ids, + num_out_tokens=self.num_out_tokens, + ) + return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert + + def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor: + num_local_tokens_per_expert = torch.histc(topk_ids, + bins=self.num_experts, + min=0, + max=self.num_experts) + + ep_size = self.ep_size + + # Dropless + self.num_out_tokens = topk_ids.numel() + + # =================================================== + # Calculate input_splits, output_splits for alltoall-v. + # =================================================== + self.input_splits = (num_local_tokens_per_expert.reshape( + ep_size, + self.num_local_experts).sum(axis=1).to(torch.device("cpu"), + non_blocking=True).numpy()) + num_global_tokens_per_expert = gather_from_sequence_parallel_region( + num_local_tokens_per_expert, + group=self.ep_group).reshape(ep_size, self.num_experts) + self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ + 0]:self.local_expert_indices[-1] + 1] + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before sum.") + self.output_splits = (self.num_global_tokens_per_local_expert.sum( + axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) + num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( + axis=0) + # =================================================== + # num_global_tokens_per_expert: [ep_size, num_experts] + # num_global_tokens_per_local_expert: [ep_size, num_local_experts] + # num_tokens_per_local_expert: [num_local_experts] + # =================================================== + + if self.num_local_experts > 1: + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before operations." + ) + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.expert_ids_per_ep_rank, + self.num_global_tokens_per_local_expert.ravel()) + + return num_tokens_per_local_expert + + def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None): + # Early return if no local experts or no tokens + if self.num_local_experts <= 1: + return global_input_tokens, None + + # Handle quantized case + if self.with_quant: + assert self.global_input_tokens_local_experts_indices is not None, \ + "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess" + expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze( + -1) + active_num = self.global_input_tokens_local_experts_indices.numel() + + # Handle case with no active tokens + if active_num <= 0: + self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices + return global_input_tokens, dynamic_scale + + # Process with active tokens + global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2( + global_input_tokens, + expert_idx_2d, + scale=dynamic_scale, + active_num=active_num, + expert_capacity=0, + expert_num=self.num_local_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[0, self.num_local_experts], + quant_mode=-1, + row_idx_type=0) + return global_input_tokens, expanded_scale + + # Handle non-quantized case + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + global_input_tokens, + self.global_input_tokens_local_experts_indices) + return global_input_tokens, None + + def _combine_preprocess(self, hidden_states): + # Unpermutation 2: expert output to AlltoAll input + if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + hidden_states = torch_npu.npu_moe_token_unpermute( + hidden_states, self.reversed_global_input_permutation_mapping) + + return hidden_states + + def _combine_postprocess(self, permutated_local_input_tokens): + # Unpermutation 1: AlltoAll output to output + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=permutated_local_input_tokens, + sorted_indices=self.reversed_local_input_permutation_mapping.to( + torch.int32), + probs=self.topk_weights, + restore_shape=self.hidden_shape_before_permute) + + # Reshape the output tensor + output = output.view(self.hidden_shape) + return output