#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
"""Unit tests for the MoE token dispatchers (MC2, AllGather, All2AllV)."""

from unittest.mock import MagicMock, PropertyMock, patch

import numpy as np
import pytest
import torch

from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEAllGatherCombineMetadata,
    MoEAllToAllCombineMetadata,
    MoEMC2CombineMetadata,
    MoEQuantParams,
    MoERoutingParams,
    MoETokenDispatchInput,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import (  # isort: skip
    AscendDeviceType,
    TokenDispatcherWithAll2AllV,
    TokenDispatcherWithAllGather,
    TokenDispatcherWithMC2,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType

# Shared skip marker for tests that go through register_kernels, which (per
# the reason text) performs an NPU SocName check in CANN 8.5.0.
_SKIP_CANN_850 = pytest.mark.skip(
    reason="Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")


def build_token_dispatch_input_fixture(
    *,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_map: torch.Tensor | None = None,
    global_redundant_expert_num: int = 0,
    apply_router_weight_on_input: bool = False,
    pertoken_scale: torch.Tensor | None = None,
    quant_type: QuantType = QuantType.NONE,
    comm_quant_mode: int | None = None,
    act_quant_type: torch.dtype | None = None,
) -> MoETokenDispatchInput:
    """Build a ``MoETokenDispatchInput`` for the dispatcher tests.

    Routing arguments are packed into ``MoERoutingParams`` (``mc2_mask`` is
    always ``None``). Quantization arguments are packed into
    ``MoEQuantParams``. An ``MoEMxfpParams`` carrying ``act_quant_type`` is
    attached only when ``quant_type`` is ``QuantType.MXFP8``; for every other
    quant type ``mxfp`` is ``None``.
    """
    mxfp_spec = None
    if quant_type == QuantType.MXFP8:
        mxfp_spec = MoEMxfpParams(act_quant_type=act_quant_type)
    return MoETokenDispatchInput(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        routing=MoERoutingParams(
            expert_map=expert_map,
            global_redundant_expert_num=global_redundant_expert_num,
            mc2_mask=None,
            apply_router_weight_on_input=apply_router_weight_on_input,
            pertoken_scale=pertoken_scale,
        ),
        quant=MoEQuantParams(
            quant_type=quant_type,
            comm_quant_mode=comm_quant_mode,
            mxfp=mxfp_spec,
        ),
    )


class TestTokenDispatcherWithMC2(TestBase):
    """Tests for ``TokenDispatcherWithMC2``.

    The vLLM config, the MC2 group, ``torch.distributed.get_rank``, the
    forward context and the Ascend device type are all patched out.
    """

    def setUp(self):
        # Every patch registers its stop via addCleanup right after start().
        # That way it is undone even if a later setUp step raises (tearDown is
        # skipped in that case) and no patch outlives this test.
        self.config_patcher = patch(
            'vllm_ascend.ops.fused_moe.token_dispatcher.get_current_vllm_config'
        )
        self.mock_get_config = self.config_patcher.start()
        self.addCleanup(self.config_patcher.stop)
        mock_config = MagicMock()
        mock_config.scheduler_config.max_num_seqs = 256
        mock_config.scheduler_config.decode_max_num_seqs = 256
        mock_config.compilation_config.custom_ops = ["all"]
        mock_config.speculative_config = None
        mock_config.parallel_config.tensor_parallel_size = 1
        self.mock_get_config.return_value = mock_config

        self.mc2_group = MagicMock()
        (self.mc2_group.device_group.return_value._get_backend.return_value.
         get_hccl_comm_name.return_value) = "hccl_123"
        self.mc2_group.rank_in_group = 0
        self.mc2_group.world_size = 8
        self.mc2_group_patch = patch(
            "vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
            return_value=self.mc2_group)
        self.mc2_group_patch.start()
        self.addCleanup(self.mc2_group_patch.stop)

        self.rank_group_patch = patch("torch.distributed.get_rank",
                                      return_value=0)
        self.rank_group_patch.start()
        self.addCleanup(self.rank_group_patch.stop)

        # Mock get_forward_context().mc2_mask
        self.forward_context = MagicMock()
        self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
        self.forward_context_patch = patch(
            "vllm.forward_context.get_forward_context",
            return_value=self.forward_context)
        self.forward_context_patch.start()
        self.addCleanup(self.forward_context_patch.stop)

        # Mock get_ascend_device_type()
        self.ascend_soc_version_patch = patch(
            "vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_device_type",
            return_value=AscendDeviceType.A3)
        self.ascend_soc_version_patch.start()
        self.addCleanup(self.ascend_soc_version_patch.stop)

        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
        self.dispatcher = TokenDispatcherWithMC2(**kwargs)

    def test_init(self):
        self.assertEqual(self.dispatcher.ep_rank_id, 0)
        self.assertEqual(self.dispatcher.ep_world_size, 8)
        self.assertTrue(self.dispatcher.enable_dispatch_v2)
        self.assertTrue(self.dispatcher.need_extra_args)

    def test_get_dispatch_mc2_kwargs_without_quant(self):
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            global_redundant_expert_num=0,
            apply_router_weight_on_input=False,
            pertoken_scale=None,
        )
        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input)
        self.assertIn("x", kwargs)
        self.assertIn("expert_ids", kwargs)
        self.assertEqual(kwargs["moe_expert_num"], 8)

    def test_token_permutation_dispatch(self):
        hidden_states = torch.randn(10, 128)
        topk_weights = torch.randn(10, 1)
        topk_ids = torch.randint(0, 8, (10, 1))
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=(torch.randn(10, 128), ) * 5 +
                   (None, None)) as mock_dispatch:
            token_dispatch_input = build_token_dispatch_input_fixture(
                hidden_states=hidden_states,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                expert_map=expert_map,
            )
            output = self.dispatcher.token_dispatch(
                token_dispatch_input=token_dispatch_input)
            mock_dispatch.assert_called_once()
            self.assertEqual(output.group_list_type,
                             0)  # group_list_type == 0
            self.assertIsInstance(output.combine_metadata,
                                  MoEMC2CombineMetadata)

    def test_get_combine_mc_kwargs_with_quant(self):
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        assist_info_for_combine = torch.arange(10)
        combine_metadata = MoEMC2CombineMetadata(
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            expert_map=expert_map,
            ep_recv_counts=ep_recv_counts,
            tp_recv_counts=tp_recv_counts,
            assist_info_for_combine=assist_info_for_combine,
            expand_scales=None,
            dispatch_with_quant=True,
        )

        self.dispatcher.need_extra_args = True
        self.dispatcher.enable_dispatch_v2 = True
        self.dispatcher.moe_expert_num = len(expert_map)
        kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
                                                       combine_metadata)
        self.assertIn("tp_send_counts", kwargs)


class TestTokenDispatcherWithAllGather(TestBase):
    """Tests for ``TokenDispatcherWithAllGather``.

    ``npu_moe_init_routing_custom`` and ``npu_moe_token_unpermute`` are
    mocked with fixed return values.
    """

    def setUp(self):
        # Mock dependencies
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
            "with_quant": False,
        }
        self.dispatcher = TokenDispatcherWithAllGather(**kwargs)

        # Mock NPU functions; cleanups are registered immediately after each
        # start() so a failure part-way through setUp cannot leak a patch.
        self.patcher_npu_moe_init_routing_custom = patch(
            'torch.ops._C_ascend.npu_moe_init_routing_custom')
        self.mock_npu_moe_init_routing_custom = (
            self.patcher_npu_moe_init_routing_custom.start())
        self.addCleanup(self.patcher_npu_moe_init_routing_custom.stop)
        self.mock_npu_moe_init_routing_custom.return_value = (
            torch.randn(6, 128),  # sorted_hidden_states
            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
            torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
            torch.tensor([0, 1, 0, 1, 0, 1]))

        self.patcher_npu_moe_token_unpermute = patch(
            'torch_npu.npu_moe_token_unpermute')
        self.mock_npu_moe_token_unpermute = (
            self.patcher_npu_moe_token_unpermute.start())
        self.addCleanup(self.patcher_npu_moe_token_unpermute.stop)
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)

    @_SKIP_CANN_850
    def test_token_dispatch_without_expert_map(self):
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
        results = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)

        # Verify npu_moe_init_routing is called
        self.mock_npu_moe_init_routing_custom.assert_called_once()

        self.assertEqual(results.group_list_type, 1)
        self.assertIsInstance(results.combine_metadata,
                              MoEAllGatherCombineMetadata)

    @_SKIP_CANN_850
    def test_token_dispatch_with_expert_map(self):
        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
        results = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)

        # Verify npu_moe_init_routing is called
        self.mock_npu_moe_init_routing_custom.assert_called_once()

        self.assertEqual(results.group_list_type, 1)
        self.assertIsInstance(results.combine_metadata,
                              MoEAllGatherCombineMetadata)

    @_SKIP_CANN_850
    def test_token_dispatch_without_quant(self):
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
        }
        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
        results = self.dispatcher_quant.token_dispatch(
            token_dispatch_input=token_dispatch_input)

        self.assertEqual(results.group_list_type, 1)

    @_SKIP_CANN_850
    def test_token_dispatch_with_quant(self):
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
        }
        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_type=QuantType.W8A8,
        )
        results = self.dispatcher_quant.token_dispatch(
            token_dispatch_input=token_dispatch_input)

        self.assertIsNotNone(results.hidden_states)
        self.assertIsNotNone(results.group_list)
        self.assertIsNotNone(results.dynamic_scale)
        self.assertEqual(results.group_list_type, 1)

    @_SKIP_CANN_850
    def test_token_combine_with_expert_map(self):
        # NOTE(review): no expert map is set here, so this is currently
        # identical to test_token_combine_without_expert_map apart from the
        # unpermute-call assertion. Confirm the intended setup.
        hidden_states = torch.randn(6, 128)
        combine_metadata = MoEAllGatherCombineMetadata(
            expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
            topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
            restore_shape=torch.Size([6, 128]),
        )
        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, combine_metadata)
        self.assertEqual(final_hidden_states.shape, (6, 128))

    @_SKIP_CANN_850
    def test_token_combine_without_expert_map(self):
        hidden_states = torch.randn(6, 128)
        combine_metadata = MoEAllGatherCombineMetadata(
            expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
            topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
            restore_shape=torch.Size([6, 128]),
        )
        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, combine_metadata)
        self.mock_npu_moe_token_unpermute.assert_called_once()
        self.assertEqual(final_hidden_states.shape, (6, 128))

    @_SKIP_CANN_850
    def test_token_dispatch_with_router_weight(self):
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
        topk_ids = torch.tensor([[0], [1], [2]])

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            apply_router_weight_on_input=True,
        )
        results = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
        self.assertEqual(results.hidden_states.shape, (6, 128))
        self.assertIsInstance(results.combine_metadata,
                              MoEAllGatherCombineMetadata)


class TestTokenDispatcherWithAll2AllV(TestBase):
    """Tests for ``TokenDispatcherWithAll2AllV``.

    The EP group, rank and size properties are patched, along with the NPU
    permute/unpermute/quant/routing ops, the async all-to-all and the
    sequence-parallel gather.
    """

    def setUp(self):
        # Patch properties
        patcher1 = patch.object(TokenDispatcherWithAll2AllV,
                                'ep_group',
                                new_callable=PropertyMock,
                                return_value=MagicMock())
        patcher2 = patch.object(TokenDispatcherWithAll2AllV,
                                'ep_rank',
                                new_callable=PropertyMock,
                                return_value=0)
        patcher3 = patch.object(TokenDispatcherWithAll2AllV,
                                'ep_size',
                                new_callable=PropertyMock,
                                return_value=2)
        self.addCleanup(patcher1.stop)
        self.addCleanup(patcher2.stop)
        self.addCleanup(patcher3.stop)
        self.mock_ep_group_prop = patcher1.start()
        self.mock_ep_rank_prop = patcher2.start()
        self.mock_ep_size_prop = patcher3.start()

        # Mock torch_npu.npu_moe_token_permute
        patcher4 = patch('torch_npu.npu_moe_token_permute')
        self.mock_npu_moe_token_permute = patcher4.start()
        self.addCleanup(patcher4.stop)
        self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
                                                        torch.arange(16))

        # Mock torch_npu.npu_moe_token_unpermute
        patcher5 = patch('torch_npu.npu_moe_token_unpermute')
        self.mock_npu_moe_token_unpermute = patcher5.start()
        self.addCleanup(patcher5.stop)
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)

        # Mock async_all_to_all
        patcher6 = patch(
            'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
        self.mock_async_all_to_all = patcher6.start()
        self.addCleanup(patcher6.stop)
        self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
                                                   MagicMock())

        # Mock gather_from_sequence_parallel_region
        patcher7 = patch(
            'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
        )
        self.mock_gather_from_sequence_parallel_region = patcher7.start()
        self.addCleanup(patcher7.stop)
        self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
            [[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)

        # Mock torch.histc
        patcher8 = patch('torch.histc')
        self.mock_histc = patcher8.start()
        self.addCleanup(patcher8.stop)
        self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
                                                    dtype=torch.int64)

        # Mock torch.npu.current_device
        patcher9 = patch('torch.npu.current_device')
        self.mock_current_device = patcher9.start()
        self.addCleanup(patcher9.stop)
        self.mock_current_device.return_value = 'cpu'

        # Mock torch_npu.npu_dynamic_quant
        patcher10 = patch('torch_npu.npu_dynamic_quant')
        self.mock_npu_dynamic_quant = patcher10.start()
        self.addCleanup(patcher10.stop)
        self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
                                                    torch.randn(16))

        # Mock torch.ops._C_ascend.npu_moe_init_routing_custom
        patcher11 = patch('torch.ops._C_ascend.npu_moe_init_routing_custom')
        self.mock_npu_moe_init_routing_custom = patcher11.start()
        self.addCleanup(patcher11.stop)
        self.mock_npu_moe_init_routing_custom.return_value = (torch.randn(
            16, 16), torch.arange(16), None, torch.randn(16))

        # Mock torch.repeat_interleave
        patcher12 = patch('torch.repeat_interleave')
        self.mock_repeat_interleave = patcher12.start()
        self.addCleanup(patcher12.stop)
        self.mock_repeat_interleave.return_value = torch.arange(16)

        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2,
                                                      with_quant=False)

    @_SKIP_CANN_850
    def test_token_dispatch(self):
        hidden_states = torch.randn(8, 16)
        topk_weights = torch.rand(8, 4)
        topk_ids = torch.randint(0, 4, (8, 2)).long()
        expert_map = torch.tensor([0, 1, 2, 3])

        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
        )
        result = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)

        self.assertIsNotNone(result.hidden_states)
        self.assertIsNotNone(result.group_list)
        self.assertEqual(result.group_list_type, 1)
        self.assertIsInstance(result.combine_metadata,
                              MoEAllToAllCombineMetadata)

    @_SKIP_CANN_850
    def test_token_combine(self):
        hidden_states = torch.randn(16, 16)
        combine_metadata = MoEAllToAllCombineMetadata(
            input_splits=np.array([4, 4]),
            output_splits=np.array([4, 4]),
            topk_weights=torch.rand(8, 4),
            reversed_local_input_permutation_mapping=torch.arange(8),
            reversed_global_input_permutation_mapping=torch.arange(16),
            hidden_shape=torch.Size([8, 16]),
            hidden_shape_before_permute=torch.Size([8, 16]),
        )
        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

        output = self.dispatcher.token_combine(hidden_states,
                                               combine_metadata)
        self.assertIsNotNone(output)
        self.assertEqual(output.shape, (8, 16))

    @_SKIP_CANN_850
    def test_token_dispatch_with_quant(self):
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)

        hidden_states = torch.randn(8, 16)
        topk_weights = torch.rand(8, 4)
        topk_ids = torch.randint(0, 4, (8, 2)).long()
        expert_map = torch.tensor([0, 1, 2, 3])

        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            quant_type=QuantType.W8A8,
        )
        result = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)

        self.assertIsNotNone(result.hidden_states)
        self.assertIsNotNone(result.group_list)
        self.assertIsNotNone(result.dynamic_scale)
        self.assertEqual(result.group_list_type, 1)
        self.assertIsInstance(result.combine_metadata,
                              MoEAllToAllCombineMetadata)

    @_SKIP_CANN_850
    def test_token_dispatch_with_quant_no_active_tokens(self):
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)

        self.mock_repeat_interleave.return_value = torch.tensor(
            [], dtype=torch.long)

        hidden_states = torch.randn(8, 16)
        topk_weights = torch.rand(8, 4)
        topk_ids = torch.randint(0, 4, (8, 2)).long()
        expert_map = torch.tensor([0, 1, 2, 3])

        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            quant_type=QuantType.W8A8,
        )
        result = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)

        self.assertIsNotNone(result.hidden_states)
        self.assertIsNotNone(result.group_list)
        self.assertIsNotNone(result.dynamic_scale)
        self.assertEqual(result.group_list_type, 1)