Files
xc-llm-ascend/tests/ut/ops/test_token_dispatcher.py
Jade Zheng 7d5242faca [Refactor] Formatting output types related to FuseMoE (#5481)
Currently in the Fused MoE module, functions of classes like
MoECommMethod and MoETokenDispatcher output data in dictionary or tuple
format, which hampers code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address these issues.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
2025-12-31 14:24:37 +08:00

498 lines
21 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
AscendDeviceType, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
class TestTokenDispatcherWithMC2(TestBase):
    """Unit tests for ``TokenDispatcherWithMC2``.

    The MC2 communication group, ``torch.distributed.get_rank``, the forward
    context and the Ascend device type are patched in ``setUp``; the
    ``torch_npu`` dispatch kernel is mocked inside the tests that need it.
    """

    def setUp(self):
        """Start the patches the dispatcher relies on and build the instance under test."""
        self.mc2_group = MagicMock()
        self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
        self.mc2_group.rank_in_group = 0
        self.mc2_group.world_size = 8
        self.mc2_group_patch = patch(
            "vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
            return_value=self.mc2_group)
        self.mc2_group_patch.start()

        self.rank_group_patch = patch("torch.distributed.get_rank",
                                      return_value=0)
        self.rank_group_patch.start()

        # Mock get_forward_context().mc2_mask
        self.forward_context = MagicMock()
        self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
        self.forward_context_patch = patch(
            "vllm.forward_context.get_forward_context",
            return_value=self.forward_context)
        self.forward_context_patch.start()

        # Mock get_ascend_device_type()
        self.ascend_soc_version_patch = patch(
            "vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_device_type",
            return_value=AscendDeviceType.A3)
        self.ascend_soc_version_patch.start()

        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
        self.dispatcher = TokenDispatcherWithMC2(**kwargs)

    def tearDown(self):
        """Stop every patch that ``setUp`` started."""
        self.mc2_group_patch.stop()
        # This patch must be stopped as well; otherwise the mocked
        # ``torch.distributed.get_rank`` stays installed for later tests.
        self.rank_group_patch.stop()
        self.forward_context_patch.stop()
        self.ascend_soc_version_patch.stop()

    def test_init(self):
        """The constructed dispatcher exposes the mocked rank/world size and expected flags."""
        self.assertEqual(self.dispatcher.ep_rank_id, 0)
        self.assertEqual(self.dispatcher.ep_world_size, 8)
        self.assertFalse(self.dispatcher.with_quant)
        self.assertTrue(self.dispatcher.enable_dispatch_v2)
        self.assertTrue(self.dispatcher.need_extra_args)

    def test_get_dispatch_mc2_kwargs_without_quant(self):
        """``get_dispatch_mc2_kwargs`` returns ``x``, ``expert_ids`` and ``moe_expert_num == 8``."""
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        mc2_mask = None

        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
            hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)

        self.assertIn("x", kwargs)
        self.assertIn("expert_ids", kwargs)
        self.assertEqual(kwargs["moe_expert_num"], 8)

    def test_token_permutation_dispatch(self):
        """``token_dispatch`` calls the v2 dispatch kernel once and reports ``group_list_type`` 0."""
        hidden_states = torch.randn(10, 128)
        topk_weights = torch.randn(10, 1)
        topk_ids = torch.randint(0, 8, (10, 1))
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

        # The mocked kernel returns a 7-tuple: five tensors followed by two Nones.
        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=(torch.randn(10, 128), ) * 5 +
                   (None, None)) as mock_dispatch:
            output = self.dispatcher.token_dispatch(hidden_states,
                                                    topk_weights, topk_ids,
                                                    expert_map)
            mock_dispatch.assert_called_once()
            self.assertEqual(output.group_list_type, 0)  # group_list_type == 0

    def test_token_dispatch_with_shared_experts_and_quant(self):
        """Smoke test: ``token_dispatch`` accepts a ``shared_experts`` module without raising.

        NOTE(review): despite the test name, ``with_quant`` is set to ``False``
        here and nothing is asserted on the result - confirm whether the
        quantized path was meant to be covered.
        """
        shared_experts = MagicMock()
        shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                    torch.tensor(1.0))
        shared_experts.act_fn.return_value = torch.randn(10, 128)

        self.dispatcher.with_quant = False
        self.dispatcher.shared_act = torch.randn(10, 128)
        self.dispatcher.swiglu_out_scale = torch.tensor(1.0)

        hidden_states = torch.randn(10, 128)
        topk_weights = torch.randn(10, 1)

        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=(torch.randn(10, 128), ) * 5 + (None, None)):
            self.dispatcher.token_dispatch(hidden_states,
                                           topk_weights,
                                           torch.randint(0, 8, (10, 1)),
                                           torch.tensor(
                                               [0, 1, 2, 3, 4, 5, 6, 7]),
                                           shared_experts=shared_experts)

    def test_get_combine_mc_kwargs_with_quant(self):
        """``get_combine_mc_kwargs`` includes ``tp_send_counts`` when extra args and v2 are enabled."""
        self.dispatcher.with_quant = True
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        mc2_mask = None
        assist_info_for_combine = torch.arange(10)

        context_metadata = {
            "topk_ids": topk_ids,
            "topk_weights": topk_weights,
            "expert_map": expert_map,
            "ep_recv_counts": ep_recv_counts,
            "mc2_mask": mc2_mask,
            "assist_info_for_combine": assist_info_for_combine,
            "expand_scales": None,
            "tp_recv_counts": tp_recv_counts
        }

        self.dispatcher.need_extra_args = True
        self.dispatcher.enable_dispatch_v2 = True

        kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
                                                       context_metadata)
        self.assertIn("tp_send_counts", kwargs)
class TestTokenDispatcherWithAllGather(TestBase):
    """Unit tests for ``TokenDispatcherWithAllGather``.

    ``torch_npu.npu_moe_init_routing_v2`` and
    ``torch_npu.npu_moe_token_unpermute`` are replaced by mocks with fixed
    return values for the duration of each test.
    """

    def setUp(self):
        """Build a dispatcher with ``with_quant=False`` and start the two ``torch_npu`` patches."""
        # Mock dependencies
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
            "with_quant": False,
        }
        self.dispatcher = TokenDispatcherWithAllGather(**kwargs)

        # Mock NPU functions
        self.patcher_npu_moe_init_routing_v2 = patch(
            'torch_npu.npu_moe_init_routing_v2')
        self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
        )
        self.mock_npu_moe_init_routing_v2.return_value = (
            torch.randn(6, 128),  # sorted_hidden_states
            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
            torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
            torch.tensor([0, 1, 0, 1, 0, 1]))

        self.patcher_npu_moe_token_unpermute = patch(
            'torch_npu.npu_moe_token_unpermute')
        self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
        )
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)

    def tearDown(self):
        """Stop both ``torch_npu`` patches started in ``setUp``."""
        self.patcher_npu_moe_init_routing_v2.stop()
        self.patcher_npu_moe_token_unpermute.stop()

    def test_token_dispatch_without_expert_map(self):
        """Dispatch without an expert map calls the routing op once; ``group_list_type`` is 1."""
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, None)

        # Verify npu_moe_init_routing_v2 is called exactly once
        self.mock_npu_moe_init_routing_v2.assert_called_once()
        self.assertEqual(results.group_list_type, 1)

    def test_token_dispatch_with_expert_map(self):
        """Dispatch after setting ``dispatcher.expert_map``; routing op is called once.

        NOTE(review): the map is only assigned as an attribute on the
        dispatcher while the fourth positional argument (``expert_map`` in the
        other dispatcher tests) is still ``None`` - confirm this really
        exercises the expert-map path.
        """
        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, None)

        # Verify npu_moe_init_routing_v2 is called exactly once
        self.mock_npu_moe_init_routing_v2.assert_called_once()
        self.assertEqual(results.group_list_type, 1)

    def test_token_dispatch_without_quant(self):
        """A dispatcher built without a ``with_quant`` key still reports ``group_list_type`` 1."""
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
        }
        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                       topk_weights, topk_ids,
                                                       None)

        self.assertEqual(results.group_list_type, 1)

    def test_token_dispatch_with_quant(self):
        """Dispatch with ``with_quant=True`` populates hidden states, group list and dynamic scale."""
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
        }
        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                       topk_weights,
                                                       topk_ids,
                                                       None,
                                                       with_quant=True)

        self.assertIsNotNone(results.hidden_states)
        self.assertIsNotNone(results.group_list)
        self.assertIsNotNone(results.dynamic_scale)
        self.assertEqual(results.group_list_type, 1)

    def test_token_combine_with_expert_map(self):
        """``token_combine(...).routed_out`` has shape ``(6, 128)``."""
        hidden_states = torch.randn(6, 128)
        context_metadata = {
            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        }
        self.dispatcher.original_shape = (6, 128)

        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, context_metadata).routed_out

        self.assertEqual(final_hidden_states.shape, (6, 128))

    def test_token_combine_without_expert_map(self):
        """``token_combine`` calls the unpermute op once and yields a ``(6, 128)`` output."""
        hidden_states = torch.randn(6, 128)
        context_metadata = {
            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        }
        self.dispatcher.original_shape = (6, 128)

        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, context_metadata).routed_out

        self.mock_npu_moe_token_unpermute.assert_called_once()
        self.assertEqual(final_hidden_states.shape, (6, 128))

    def test_token_dispatch_with_router_weight(self):
        """With ``apply_router_weight_on_input`` enabled, dispatched hidden states are ``(6, 128)``."""
        self.dispatcher.apply_router_weight_on_input = True
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
        topk_ids = torch.tensor([[0], [1], [2]])

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, None)

        # (6, 128) matches the first tensor returned by the mocked routing op;
        # presumably that tensor is what ends up in results.hidden_states.
        self.assertEqual(results.hidden_states.shape, (6, 128))
class TestTokenDispatcherWithAll2AllV(TestBase):
    """Unit tests for ``TokenDispatcherWithAll2AllV``.

    The expert-parallel properties on the class, the ``torch_npu`` kernels,
    the all-to-all / gather helpers and a few ``torch`` functions are all
    patched in ``setUp`` and restored through ``addCleanup``.
    """

    def _install(self, patcher):
        """Start *patcher*, register its ``stop`` as a cleanup, and return the installed mock."""
        installed = patcher.start()
        self.addCleanup(patcher.stop)
        return installed

    def _make_dispatch_inputs(self):
        """Create the random dispatch inputs and pin the dispatcher's local expert layout.

        Returns the keyword arguments shared by every ``token_dispatch`` call
        in this class.
        """
        call_kwargs = {
            "hidden_states": torch.randn(8, 16),
            "topk_weights": torch.rand(8, 4),
            "topk_ids": torch.randint(0, 4, (8, 2)).long(),
            "expert_map": torch.tensor([0, 1, 2, 3]),
        }
        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]
        return call_kwargs

    def _check_dispatch_result(self, result, expect_dynamic_scale=False):
        """Assert the common post-conditions on a ``token_dispatch`` result."""
        self.assertIsNotNone(result.hidden_states)
        self.assertIsNotNone(result.group_list)
        if expect_dynamic_scale:
            self.assertIsNotNone(result.dynamic_scale)
        self.assertEqual(result.group_list_type, 1)

    def setUp(self):
        """Install all mocks and construct a dispatcher with ``with_quant=False``."""
        # Expert-parallel properties are replaced on the class itself.
        self.mock_ep_group_prop = self._install(
            patch.object(TokenDispatcherWithAll2AllV,
                         'ep_group',
                         new_callable=PropertyMock,
                         return_value=MagicMock()))
        self.mock_ep_rank_prop = self._install(
            patch.object(TokenDispatcherWithAll2AllV,
                         'ep_rank',
                         new_callable=PropertyMock,
                         return_value=0))
        self.mock_ep_size_prop = self._install(
            patch.object(TokenDispatcherWithAll2AllV,
                         'ep_size',
                         new_callable=PropertyMock,
                         return_value=2))

        # torch_npu permute / unpermute kernels.
        self.mock_npu_moe_token_permute = self._install(
            patch('torch_npu.npu_moe_token_permute'))
        self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
                                                        torch.arange(16))

        self.mock_npu_moe_token_unpermute = self._install(
            patch('torch_npu.npu_moe_token_unpermute'))
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)

        # Communication helpers.
        self.mock_async_all_to_all = self._install(
            patch('vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all'))
        self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
                                                   MagicMock())

        self.mock_gather_from_sequence_parallel_region = self._install(
            patch(
                'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
            ))
        self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
            [[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)

        # torch.histc is pinned to a fixed per-expert count.
        self.mock_histc = self._install(patch('torch.histc'))
        self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
                                                    dtype=torch.int64)

        # torch.npu.current_device reports 'cpu' under test.
        self.mock_current_device = self._install(
            patch('torch.npu.current_device'))
        self.mock_current_device.return_value = 'cpu'

        # torch_npu quantization and routing kernels.
        self.mock_npu_dynamic_quant = self._install(
            patch('torch_npu.npu_dynamic_quant'))
        self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
                                                    torch.randn(16))

        self.mock_npu_moe_init_routing_v2 = self._install(
            patch('torch_npu.npu_moe_init_routing_v2'))
        self.mock_npu_moe_init_routing_v2.return_value = (torch.randn(
            16, 16), torch.arange(16), None, torch.randn(16))

        # torch.repeat_interleave returns a fixed index vector.
        self.mock_repeat_interleave = self._install(
            patch('torch.repeat_interleave'))
        self.mock_repeat_interleave.return_value = torch.arange(16)

        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2,
                                                      with_quant=False)

    def test_token_dispatch(self):
        """Plain dispatch yields hidden states, a group list and ``group_list_type`` 1."""
        call_kwargs = self._make_dispatch_inputs()

        outcome = self.dispatcher.token_dispatch(**call_kwargs)

        self._check_dispatch_result(outcome)

    def test_token_combine(self):
        """``token_combine`` returns an object whose ``routed_out`` has shape ``(8, 16)``."""
        expert_output = torch.randn(16, 16)
        context_metadata = {
            "input_splits": [4, 4],
            "output_splits": [4, 4],
            "topk_weights": torch.rand(8, 4),
            "reversed_local_input_permutation_mapping": torch.arange(8),
            "reversed_global_input_permutation_mapping": torch.arange(16),
        }

        dispatcher = self.dispatcher
        dispatcher.hidden_shape = (8, 16)
        dispatcher.hidden_shape_before_permute = (8, 16)
        dispatcher.expert_ids_per_ep_rank = torch.tensor([0, 1],
                                                         dtype=torch.int32)
        dispatcher.local_expert_indices = [0, 1]

        combined = dispatcher.token_combine(expert_output, context_metadata)

        self.assertIsNotNone(combined)
        self.assertEqual(combined.routed_out.shape, (8, 16))

    def test_token_dispatch_with_quant(self):
        """Dispatch with ``with_quant=True`` additionally yields a dynamic scale."""
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)
        call_kwargs = self._make_dispatch_inputs()

        outcome = self.dispatcher.token_dispatch(with_quant=True,
                                                 **call_kwargs)

        self._check_dispatch_result(outcome, expect_dynamic_scale=True)

    def test_token_dispatch_with_quant_no_active_tokens(self):
        """Quantized dispatch still succeeds when ``torch.repeat_interleave`` returns an empty tensor."""
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)
        # Override the setUp default with an empty index tensor.
        self.mock_repeat_interleave.return_value = torch.tensor(
            [], dtype=torch.long)
        call_kwargs = self._make_dispatch_inputs()

        outcome = self.dispatcher.token_dispatch(with_quant=True,
                                                 **call_kwargs)

        self._check_dispatch_result(outcome, expect_dynamic_scale=True)

    def test_token_dispatch_with_log2phy(self):
        """Dispatch accepts a ``log2phy`` tensor and returns the usual result fields."""
        call_kwargs = self._make_dispatch_inputs()
        call_kwargs["log2phy"] = torch.tensor([1, 0, 3, 2])

        outcome = self.dispatcher.token_dispatch(**call_kwargs)

        self._check_dispatch_result(outcome)