Currently, in the Fused MoE module, methods of classes such as
MoECommMethod and MoETokenDispatcher return their outputs as dictionaries
or tuples, which hurts code maintainability, readability, and
extensibility. This PR introduces dataclasses for these key output types
to address these issues.
- vLLM version: v0.13.0
- vLLM main: 5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
498 lines
21 KiB
Python
498 lines
21 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
from unittest.mock import MagicMock, PropertyMock, patch
|
|
|
|
import torch
|
|
|
|
from tests.ut.base import TestBase
|
|
|
|
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
|
|
AscendDeviceType, TokenDispatcherWithAll2AllV,
|
|
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
|
|
|
|
|
class TestTokenDispatcherWithMC2(TestBase):
    """Unit tests for ``TokenDispatcherWithMC2``.

    ``setUp`` patches the MC2 group lookup, ``torch.distributed.get_rank``,
    ``vllm.forward_context.get_forward_context`` and the Ascend device-type
    query (pinned to ``AscendDeviceType.A3``), so the dispatcher can be
    built without real HCCL/NPU resources. ``tearDown`` stops every one of
    those patches.
    """

    def setUp(self):
        # Fake MC2 group: rank 0 of an 8-way group whose backend reports a
        # fixed HCCL communicator name.
        self.mc2_group = MagicMock()
        backend = self.mc2_group.device_group.return_value._get_backend.return_value
        backend.get_hccl_comm_name.return_value = "hccl_123"
        self.mc2_group.rank_in_group = 0
        self.mc2_group.world_size = 8
        self.mc2_group_patch = patch(
            "vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
            return_value=self.mc2_group)
        self.mc2_group_patch.start()

        self.rank_group_patch = patch("torch.distributed.get_rank",
                                      return_value=0)
        self.rank_group_patch.start()

        # Mock get_forward_context().mc2_mask
        self.forward_context = MagicMock()
        self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
        self.forward_context_patch = patch(
            "vllm.forward_context.get_forward_context",
            return_value=self.forward_context)
        self.forward_context_patch.start()

        # Mock get_ascend_device_type()
        self.ascend_soc_version_patch = patch(
            "vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_device_type",
            return_value=AscendDeviceType.A3)
        self.ascend_soc_version_patch.start()

        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
        self.dispatcher = TokenDispatcherWithMC2(**kwargs)

    def tearDown(self):
        # Stop every patch started in setUp, in reverse start order. The
        # get_rank patch must be stopped too; otherwise the mocked
        # torch.distributed.get_rank leaks into every later test.
        self.ascend_soc_version_patch.stop()
        self.forward_context_patch.stop()
        self.rank_group_patch.stop()
        self.mc2_group_patch.stop()

    def test_init(self):
        # With the patched group (rank 0 of 8) and A3 device type, the
        # dispatcher is expected to enable dispatch v2 and extra args.
        self.assertEqual(self.dispatcher.ep_rank_id, 0)
        self.assertEqual(self.dispatcher.ep_world_size, 8)
        self.assertFalse(self.dispatcher.with_quant)
        self.assertTrue(self.dispatcher.enable_dispatch_v2)
        self.assertTrue(self.dispatcher.need_extra_args)

    def test_get_dispatch_mc2_kwargs_without_quant(self):
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        mc2_mask = None

        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
            hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
        self.assertIn("x", kwargs)
        self.assertIn("expert_ids", kwargs)
        self.assertEqual(kwargs["moe_expert_num"], 8)

    def test_token_permutation_dispatch(self):
        hidden_states = torch.randn(10, 128)
        topk_weights = torch.randn(10, 1)
        topk_ids = torch.randint(0, 8, (10, 1))
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

        # The mocked dispatch kernel returns a 7-tuple: five tensors
        # followed by two ``None`` entries.
        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=(torch.randn(10, 128), ) * 5 +
                   (None, None)) as mock_dispatch:
            output = self.dispatcher.token_dispatch(hidden_states,
                                                    topk_weights, topk_ids,
                                                    expert_map)
            mock_dispatch.assert_called_once()
            self.assertEqual(output.group_list_type, 0)

    def test_token_dispatch_with_shared_experts_and_quant(self):
        # NOTE(review): despite the test name, quantization is disabled
        # (with_quant = False); this covers dispatch with shared experts.
        shared_experts = MagicMock()
        shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                    torch.tensor(1.0))
        shared_experts.act_fn.return_value = torch.randn(10, 128)
        self.dispatcher.with_quant = False
        self.dispatcher.shared_act = torch.randn(10, 128)
        self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
        hidden_states = torch.randn(10, 128)
        topk_weights = torch.randn(10, 1)

        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=(torch.randn(10, 128), ) * 5 +
                   (None, None)) as mock_dispatch:
            self.dispatcher.token_dispatch(hidden_states,
                                           topk_weights,
                                           torch.randint(0, 8, (10, 1)),
                                           torch.tensor(
                                               [0, 1, 2, 3, 4, 5, 6, 7]),
                                           shared_experts=shared_experts)

        # The test previously asserted nothing; at minimum the MC2 dispatch
        # kernel must have been invoked exactly once.
        mock_dispatch.assert_called_once()

    def test_get_combine_mc_kwargs_with_quant(self):
        self.dispatcher.with_quant = True
        hidden_states = torch.randn(10, 128)
        topk_ids = torch.randint(0, 8, (10, 1))
        topk_weights = torch.randn(10, 1)
        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        mc2_mask = None
        assist_info_for_combine = torch.arange(10)

        context_metadata = {
            "topk_ids": topk_ids,
            "topk_weights": topk_weights,
            "expert_map": expert_map,
            "ep_recv_counts": ep_recv_counts,
            "mc2_mask": mc2_mask,
            "assist_info_for_combine": assist_info_for_combine,
            "expand_scales": None,
            "tp_recv_counts": tp_recv_counts,
        }

        self.dispatcher.need_extra_args = True
        self.dispatcher.enable_dispatch_v2 = True

        kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
                                                       context_metadata)
        # NOTE(review): presumably tp_send_counts is derived from the
        # tp_recv_counts supplied in context_metadata; only presence is
        # checked here.
        self.assertIn("tp_send_counts", kwargs)
|
|
|
|
|
|
class TestTokenDispatcherWithAllGather(TestBase):
    """Unit tests for ``TokenDispatcherWithAllGather``.

    ``torch_npu.npu_moe_init_routing_v2`` and
    ``torch_npu.npu_moe_token_unpermute`` are patched with canned return
    values. These tests therefore exercise the dispatcher's control flow
    and the fields of its outputs, not real routing numerics.
    """

    @staticmethod
    def _make_dispatcher(**extra_kwargs):
        """Build a ``TokenDispatcherWithAllGather`` with this suite's defaults.

        Args:
            **extra_kwargs: Constructor arguments to add to, or override,
                the shared defaults.

        Returns:
            A newly constructed dispatcher.
        """
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
        }
        kwargs.update(extra_kwargs)
        return TokenDispatcherWithAllGather(**kwargs)

    @staticmethod
    def _make_top2_inputs():
        """Return ``(hidden_states, topk_weights, topk_ids)`` for 3 tokens, top-2."""
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
        return hidden_states, topk_weights, topk_ids

    def setUp(self):
        self.dispatcher = self._make_dispatcher(with_quant=False)

        # Mock NPU functions
        self.patcher_npu_moe_init_routing_v2 = patch(
            'torch_npu.npu_moe_init_routing_v2')
        self.mock_npu_moe_init_routing_v2 = (
            self.patcher_npu_moe_init_routing_v2.start())
        self.mock_npu_moe_init_routing_v2.return_value = (
            torch.randn(6, 128),  # sorted_hidden_states
            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
            torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
            torch.tensor([0, 1, 0, 1, 0, 1]))
        self.patcher_npu_moe_token_unpermute = patch(
            'torch_npu.npu_moe_token_unpermute')
        self.mock_npu_moe_token_unpermute = (
            self.patcher_npu_moe_token_unpermute.start())
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)

    def tearDown(self):
        self.patcher_npu_moe_init_routing_v2.stop()
        self.patcher_npu_moe_token_unpermute.stop()

    def test_token_dispatch_without_expert_map(self):
        hidden_states, topk_weights, topk_ids = self._make_top2_inputs()

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, None)

        # The routing kernel must be invoked exactly once per dispatch.
        self.mock_npu_moe_init_routing_v2.assert_called_once()
        self.assertEqual(results.group_list_type, 1)

    def test_token_dispatch_with_expert_map(self):
        # NOTE(review): expert_map is set as a dispatcher attribute, but the
        # token_dispatch call below still passes None as its expert_map
        # argument. Confirm this test really covers the expert-map path.
        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
        hidden_states, topk_weights, topk_ids = self._make_top2_inputs()

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, None)

        # The routing kernel must be invoked exactly once per dispatch.
        self.mock_npu_moe_init_routing_v2.assert_called_once()
        self.assertEqual(results.group_list_type, 1)

    def test_token_dispatch_without_quant(self):
        dispatcher = self._make_dispatcher()
        hidden_states, topk_weights, topk_ids = self._make_top2_inputs()

        results = dispatcher.token_dispatch(hidden_states, topk_weights,
                                            topk_ids, None)

        self.assertEqual(results.group_list_type, 1)

    def test_token_dispatch_with_quant(self):
        dispatcher = self._make_dispatcher()
        hidden_states, topk_weights, topk_ids = self._make_top2_inputs()

        results = dispatcher.token_dispatch(hidden_states,
                                            topk_weights,
                                            topk_ids,
                                            None,
                                            with_quant=True)

        self.assertIsNotNone(results.hidden_states)
        self.assertIsNotNone(results.group_list)
        self.assertIsNotNone(results.dynamic_scale)
        self.assertEqual(results.group_list_type, 1)

    def test_token_combine_with_expert_map(self):
        # NOTE(review): setup is identical to
        # test_token_combine_without_expert_map; no expert map is involved.
        hidden_states = torch.randn(6, 128)
        context_metadata = {
            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        }
        self.dispatcher.original_shape = (6, 128)

        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, context_metadata).routed_out

        # Same inputs and state as the sibling test, so the unpermute
        # kernel must likewise be called exactly once.
        self.mock_npu_moe_token_unpermute.assert_called_once()
        self.assertEqual(final_hidden_states.shape, (6, 128))

    def test_token_combine_without_expert_map(self):
        hidden_states = torch.randn(6, 128)
        context_metadata = {
            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
        }
        self.dispatcher.original_shape = (6, 128)

        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, context_metadata).routed_out

        self.mock_npu_moe_token_unpermute.assert_called_once()
        self.assertEqual(final_hidden_states.shape, (6, 128))

    def test_token_dispatch_with_router_weight(self):
        self.dispatcher.apply_router_weight_on_input = True
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
        topk_ids = torch.tensor([[0], [1], [2]])

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, None)

        # NOTE(review): the expected (6, 128) shape matches the first tensor
        # of the mocked npu_moe_init_routing_v2 return value, presumably
        # passed through as hidden_states.
        self.assertEqual(results.hidden_states.shape, (6, 128))
|
|
|
|
|
|
class TestTokenDispatcherWithAll2AllV(TestBase):
    """Unit tests for ``TokenDispatcherWithAll2AllV``.

    ``setUp`` replaces the dispatcher's ``ep_group`` / ``ep_rank`` /
    ``ep_size`` properties, and every NPU kernel or collective helper the
    tests rely on, with mocks. It then builds a dispatcher with
    ``top_k=2``, ``num_experts=4`` and ``num_local_experts=2``. Every patch
    is undone through ``addCleanup``.
    """

    def _start_patch(self, patcher):
        """Register ``patcher.stop`` as a cleanup, start it, and return its mock."""
        self.addCleanup(patcher.stop)
        return patcher.start()

    @staticmethod
    def _make_inputs():
        """Build ``(hidden_states, topk_weights, topk_ids, expert_map)`` for 8 tokens."""
        hidden = torch.randn(8, 16)
        weights = torch.rand(8, 4)
        ids = torch.randint(0, 4, (8, 2)).long()
        return hidden, weights, ids, torch.tensor([0, 1, 2, 3])

    @staticmethod
    def _pin_local_experts(dispatcher):
        """Fix the dispatcher's per-rank expert ids and local expert indices to [0, 1]."""
        dispatcher.expert_ids_per_ep_rank = torch.tensor([0, 1],
                                                         dtype=torch.int32)
        dispatcher.local_expert_indices = [0, 1]

    def _assert_dispatch_output(self, result, *, quantized=False):
        """Check the fields every token_dispatch result in this suite must have."""
        self.assertIsNotNone(result.hidden_states)
        self.assertIsNotNone(result.group_list)
        if quantized:
            self.assertIsNotNone(result.dynamic_scale)
        self.assertEqual(result.group_list_type, 1)

    def setUp(self):
        # Expert-parallel properties: a mocked group, rank 0, EP size 2.
        self.mock_ep_group_prop = self._start_patch(
            patch.object(TokenDispatcherWithAll2AllV, 'ep_group',
                         new_callable=PropertyMock,
                         return_value=MagicMock()))
        self.mock_ep_rank_prop = self._start_patch(
            patch.object(TokenDispatcherWithAll2AllV, 'ep_rank',
                         new_callable=PropertyMock, return_value=0))
        self.mock_ep_size_prop = self._start_patch(
            patch.object(TokenDispatcherWithAll2AllV, 'ep_size',
                         new_callable=PropertyMock, return_value=2))

        # NPU token permute / unpermute kernels.
        self.mock_npu_moe_token_permute = self._start_patch(
            patch('torch_npu.npu_moe_token_permute'))
        self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
                                                        torch.arange(16))

        self.mock_npu_moe_token_unpermute = self._start_patch(
            patch('torch_npu.npu_moe_token_unpermute'))
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)

        # Collective-communication helpers.
        self.mock_async_all_to_all = self._start_patch(
            patch('vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all'))
        self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
                                                   MagicMock())

        self.mock_gather_from_sequence_parallel_region = self._start_patch(
            patch(
                'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
            ))
        self.mock_gather_from_sequence_parallel_region.return_value = (
            torch.tensor([[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64))

        # Tensor/device utilities touched during dispatch.
        self.mock_histc = self._start_patch(patch('torch.histc'))
        self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
                                                    dtype=torch.int64)

        self.mock_current_device = self._start_patch(
            patch('torch.npu.current_device'))
        self.mock_current_device.return_value = 'cpu'

        # Quantization and routing kernels.
        self.mock_npu_dynamic_quant = self._start_patch(
            patch('torch_npu.npu_dynamic_quant'))
        self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
                                                    torch.randn(16))

        self.mock_npu_moe_init_routing_v2 = self._start_patch(
            patch('torch_npu.npu_moe_init_routing_v2'))
        self.mock_npu_moe_init_routing_v2.return_value = (
            torch.randn(16, 16), torch.arange(16), None, torch.randn(16))

        self.mock_repeat_interleave = self._start_patch(
            patch('torch.repeat_interleave'))
        self.mock_repeat_interleave.return_value = torch.arange(16)

        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2,
                                                      with_quant=False)

    def test_token_dispatch(self):
        hidden_states, topk_weights, topk_ids, expert_map = self._make_inputs()
        self._pin_local_experts(self.dispatcher)

        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
                                                expert_map=expert_map)

        self._assert_dispatch_output(result)

    def test_token_combine(self):
        hidden_states = torch.randn(16, 16)
        context_metadata = {
            "input_splits": [4, 4],
            "output_splits": [4, 4],
            "topk_weights": torch.rand(8, 4),
            "reversed_local_input_permutation_mapping": torch.arange(8),
            "reversed_global_input_permutation_mapping": torch.arange(16),
        }
        self.dispatcher.hidden_shape = (8, 16)
        self.dispatcher.hidden_shape_before_permute = (8, 16)
        self._pin_local_experts(self.dispatcher)

        combined = self.dispatcher.token_combine(hidden_states,
                                                 context_metadata)

        self.assertIsNotNone(combined)
        self.assertEqual(combined.routed_out.shape, (8, 16))

    def test_token_dispatch_with_quant(self):
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)
        hidden_states, topk_weights, topk_ids, expert_map = self._make_inputs()
        self._pin_local_experts(self.dispatcher)

        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
                                                expert_map=expert_map,
                                                with_quant=True)

        self._assert_dispatch_output(result, quantized=True)

    def test_token_dispatch_with_quant_no_active_tokens(self):
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)
        # An empty repeat_interleave result simulates no active tokens.
        self.mock_repeat_interleave.return_value = torch.tensor(
            [], dtype=torch.long)
        hidden_states, topk_weights, topk_ids, expert_map = self._make_inputs()
        self._pin_local_experts(self.dispatcher)

        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
                                                expert_map=expert_map,
                                                with_quant=True)

        self._assert_dispatch_output(result, quantized=True)

    def test_token_dispatch_with_log2phy(self):
        hidden_states, topk_weights, topk_ids, expert_map = self._make_inputs()
        log2phy = torch.tensor([1, 0, 3, 2])
        self._pin_local_experts(self.dispatcher)

        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
                                                expert_map=expert_map,
                                                log2phy=log2phy)

        self._assert_dispatch_output(result)
|