#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock, PropertyMock, patch

import numpy as np
import pytest
import torch

from tests.ut.base import TestBase
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEAllGatherCombineMetadata,
MoEAllToAllCombineMetadata,
MoEMC2CombineMetadata,
MoEQuantParams,
MoERoutingParams,
MoETokenDispatchInput,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
AscendDeviceType,
TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
from vllm_ascend.quantization.quant_type import QuantType


def build_token_dispatch_input_fixture(
*,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
apply_router_weight_on_input: bool = False,
pertoken_scale: torch.Tensor | None = None,
quant_type: QuantType = QuantType.NONE,
comm_quant_mode: int | None = None,
act_quant_type: torch.dtype | None = None,
) -> MoETokenDispatchInput:
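    """Build a MoETokenDispatchInput for the dispatcher tests.

    Packs hidden states, top-k routing results, and quantization settings
    into one input; MXFP params are attached only for QuantType.MXFP8.
    """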
mxfp_spec = None
if quant_type == QuantType.MXFP8:
mxfp_spec = MoEMxfpParams(act_quant_type=act_quant_type)
return MoETokenDispatchInput(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
routing=MoERoutingParams(
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=None,
apply_router_weight_on_input=apply_router_weight_on_input,
pertoken_scale=pertoken_scale,
),
quant=MoEQuantParams(
quant_type=quant_type,
comm_quant_mode=comm_quant_mode,
mxfp=mxfp_spec,
),
)


class TestTokenDispatcherWithMC2(TestBase):
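    """Unit tests for TokenDispatcherWithMC2 (MC2 dispatch/combine path)."""
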
def setUp(self):
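        # Provide a minimal vllm config via get_current_vllm_config.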
self.config_patcher = patch(
'vllm_ascend.ops.fused_moe.token_dispatcher.get_current_vllm_config'
)
self.mock_get_config = self.config_patcher.start()
mock_config = MagicMock()
mock_config.scheduler_config.max_num_seqs = 256
mock_config.scheduler_config.decode_max_num_seqs = 256
mock_config.compilation_config.custom_ops = ["all"]
mock_config.speculative_config = None
mock_config.parallel_config.tensor_parallel_size = 1
self.mock_get_config.return_value = mock_config
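        # Fake MC2 group: rank 0 of 8, with a stubbed HCCL comm name.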
self.mc2_group = MagicMock()
self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8
self.mc2_group_patch = patch(
"vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
return_value=self.mc2_group)
self.mc2_group_patch.start()
self.rank_group_patch = patch("torch.distributed.get_rank",
return_value=0)
self.rank_group_patch.start()
# Mock get_forward_context().mc2_mask
self.forward_context = MagicMock()
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
self.forward_context_patch = patch(
"vllm.forward_context.get_forward_context",
return_value=self.forward_context)
self.forward_context_patch.start()
# Mock get_ascend_device_type()
self.ascend_soc_version_patch = patch(
"vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_device_type",
return_value=AscendDeviceType.A3)
self.ascend_soc_version_patch.start()
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)

    def tearDown(self):
self.mc2_group_patch.stop()
self.forward_context_patch.stop()
self.ascend_soc_version_patch.stop()

    def test_init(self):
self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8)
self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args)

    def test_get_dispatch_mc2_kwargs_without_quant(self):
hidden_states = torch.randn(10, 128)
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
global_redundant_expert_num=0,
apply_router_weight_on_input=False,
pertoken_scale=None,
)
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input)
self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8)

    def test_token_permutation_dispatch(self):
hidden_states = torch.randn(10, 128)
topk_weights = torch.randn(10, 1)
topk_ids = torch.randint(0, 8, (10, 1))
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
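        # Stub npu_moe_distribute_dispatch_v2 with five tensors and two Nones.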
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5 +
(None, None)) as mock_dispatch:
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
)
        output = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
mock_dispatch.assert_called_once()
        self.assertEqual(output.group_list_type, 0)
self.assertIsInstance(output.combine_metadata, MoEMC2CombineMetadata)

    def test_get_combine_mc_kwargs_with_quant(self):
hidden_states = torch.randn(10, 128)
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
assist_info_for_combine = torch.arange(10)
combine_metadata = MoEMC2CombineMetadata(
topk_ids=topk_ids,
topk_weights=topk_weights,
expert_map=expert_map,
ep_recv_counts=ep_recv_counts,
tp_recv_counts=tp_recv_counts,
assist_info_for_combine=assist_info_for_combine,
expand_scales=None,
dispatch_with_quant=True,
)
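        # Exercise the dispatch-v2 combine path (adds tp_send_counts).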
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.moe_expert_num = len(expert_map)
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
combine_metadata)
self.assertIn("tp_send_counts", kwargs)


class TestTokenDispatcherWithAllGather(TestBase):
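    """Unit tests for TokenDispatcherWithAllGather."""
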
def setUp(self):
# Mock dependencies
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
"with_quant": False,
}
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
# Mock NPU functions
self.patcher_npu_moe_init_routing_custom = patch(
'torch.ops._C_ascend.npu_moe_init_routing_custom')
self.mock_npu_moe_init_routing_custom = self.patcher_npu_moe_init_routing_custom.start(
)
self.mock_npu_moe_init_routing_custom.return_value = (
torch.randn(6, 128), # sorted_hidden_states
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
torch.tensor([0, 1, 0, 1, 0, 1]))
self.patcher_npu_moe_token_unpermute = patch(
'torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
)
self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)

    def tearDown(self):
self.patcher_npu_moe_init_routing_custom.stop()
self.patcher_npu_moe_token_unpermute.stop()

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_without_expert_map(self):
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
        results = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
        # Verify npu_moe_init_routing_custom is called
        self.mock_npu_moe_init_routing_custom.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
self.assertEqual(results.group_list_type, 1)
        self.assertIsInstance(results.combine_metadata,
                              MoEAllGatherCombineMetadata)

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
        results = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
        # Verify npu_moe_init_routing_custom is called
        self.mock_npu_moe_init_routing_custom.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
self.assertEqual(results.group_list_type, 1)
        self.assertIsInstance(results.combine_metadata,
                              MoEAllGatherCombineMetadata)

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_without_quant(self):
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
}
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
        results = self.dispatcher_quant.token_dispatch(
            token_dispatch_input=token_dispatch_input)
self.assertEqual(results.group_list_type, 1)

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_with_quant(self):
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
}
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_type=QuantType.W8A8,
)
        results = self.dispatcher_quant.token_dispatch(
            token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(results.hidden_states)
self.assertIsNotNone(results.group_list)
self.assertIsNotNone(results.dynamic_scale)
self.assertEqual(results.group_list_type, 1)

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine_with_expert_map(self):
hidden_states = torch.randn(6, 128)
combine_metadata = MoEAllGatherCombineMetadata(
expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
restore_shape=torch.Size([6, 128]),
)
        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, combine_metadata)
self.assertEqual(final_hidden_states.shape, (6, 128))

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine_without_expert_map(self):
hidden_states = torch.randn(6, 128)
combine_metadata = MoEAllGatherCombineMetadata(
expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
restore_shape=torch.Size([6, 128]),
)
        final_hidden_states = self.dispatcher.token_combine(
            hidden_states, combine_metadata)
self.mock_npu_moe_token_unpermute.assert_called_once()
self.assertEqual(final_hidden_states.shape, (6, 128))

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_with_router_weight(self):
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
topk_ids = torch.tensor([[0], [1], [2]])
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=True,
)
        results = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
        self.assertEqual(results.hidden_states.shape, (6, 128))
        self.assertIsInstance(results.combine_metadata,
                              MoEAllGatherCombineMetadata)


class TestTokenDispatcherWithAll2AllV(TestBase):
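    """Unit tests for TokenDispatcherWithAll2AllV."""
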
def setUp(self):
# Patch properties
patcher1 = patch.object(TokenDispatcherWithAll2AllV,
'ep_group',
new_callable=PropertyMock,
return_value=MagicMock())
patcher2 = patch.object(TokenDispatcherWithAll2AllV,
'ep_rank',
new_callable=PropertyMock,
return_value=0)
patcher3 = patch.object(TokenDispatcherWithAll2AllV,
'ep_size',
new_callable=PropertyMock,
return_value=2)
self.addCleanup(patcher1.stop)
self.addCleanup(patcher2.stop)
self.addCleanup(patcher3.stop)
self.mock_ep_group_prop = patcher1.start()
self.mock_ep_rank_prop = patcher2.start()
self.mock_ep_size_prop = patcher3.start()
# Mock torch_npu.npu_moe_token_permute
patcher4 = patch('torch_npu.npu_moe_token_permute')
self.mock_npu_moe_token_permute = patcher4.start()
self.addCleanup(patcher4.stop)
self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
torch.arange(16))
# Mock torch_npu.npu_moe_token_unpermute
patcher5 = patch('torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = patcher5.start()
self.addCleanup(patcher5.stop)
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
# Mock async_all_to_all
patcher6 = patch(
'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
self.mock_async_all_to_all = patcher6.start()
self.addCleanup(patcher6.stop)
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
MagicMock())
# Mock gather_from_sequence_parallel_region
patcher7 = patch(
'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
)
self.mock_gather_from_sequence_parallel_region = patcher7.start()
self.addCleanup(patcher7.stop)
self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
[[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)
# Mock torch.histc
patcher8 = patch('torch.histc')
self.mock_histc = patcher8.start()
self.addCleanup(patcher8.stop)
self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
dtype=torch.int64)
# Mock torch.npu.current_device
patcher9 = patch('torch.npu.current_device')
self.mock_current_device = patcher9.start()
self.addCleanup(patcher9.stop)
self.mock_current_device.return_value = 'cpu'
# Mock torch_npu.npu_dynamic_quant
patcher10 = patch('torch_npu.npu_dynamic_quant')
self.mock_npu_dynamic_quant = patcher10.start()
self.addCleanup(patcher10.stop)
self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
torch.randn(16))
# Mock torch.ops._C_ascend.npu_moe_init_routing_custom
patcher11 = patch('torch.ops._C_ascend.npu_moe_init_routing_custom')
self.mock_npu_moe_init_routing_custom = patcher11.start()
self.addCleanup(patcher11.stop)
self.mock_npu_moe_init_routing_custom.return_value = (torch.randn(
16, 16), torch.arange(16), None, torch.randn(16))
# Mock torch.repeat_interleave
patcher12 = patch('torch.repeat_interleave')
self.mock_repeat_interleave = patcher12.start()
self.addCleanup(patcher12.stop)
self.mock_repeat_interleave.return_value = torch.arange(16)
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2,
with_quant=False)

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
)
        result = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)
        self.assertIsInstance(result.combine_metadata,
                              MoEAllToAllCombineMetadata)

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_combine(self):
hidden_states = torch.randn(16, 16)
combine_metadata = MoEAllToAllCombineMetadata(
input_splits=np.array([4, 4]),
output_splits=np.array([4, 4]),
topk_weights=torch.rand(8, 4),
reversed_local_input_permutation_mapping=torch.arange(8),
reversed_global_input_permutation_mapping=torch.arange(16),
hidden_shape=torch.Size([8, 16]),
hidden_shape_before_permute=torch.Size([8, 16]),
)
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
output = self.dispatcher.token_combine(hidden_states, combine_metadata)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_with_quant(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2)
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
quant_type=QuantType.W8A8,
)
        result = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)
        self.assertIsInstance(result.combine_metadata,
                              MoEAllToAllCombineMetadata)

    @pytest.mark.skip(
"Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
def test_token_dispatch_with_quant_no_active_tokens(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2)
self.mock_repeat_interleave.return_value = torch.tensor(
[], dtype=torch.long)
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
token_dispatch_input = build_token_dispatch_input_fixture(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
quant_type=QuantType.W8A8,
)
        result = self.dispatcher.token_dispatch(
            token_dispatch_input=token_dispatch_input)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)