[main] refactor alltoallv in fused_moe (#2487)
### What this PR does / why we need it?
Refactor all2all-related fused_experts (both quantized/unquantized) into
TokenDispatcherWithAll2AllV, including dispatch & combine calculation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E & UT
- vLLM version: v0.10.0
- vLLM main:
65197a5fb3
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -15,17 +15,17 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from tests.ut.base import PytestBase, TestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2)
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
|
||||
|
||||
@@ -70,40 +70,40 @@ class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
|
||||
assert dispatcher.overlap_stream is not None
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.mc2_group = mock.MagicMock()
|
||||
self.mc2_group = MagicMock()
|
||||
self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
|
||||
self.mc2_group.rank_in_group = 0
|
||||
self.mc2_group.world_size = 8
|
||||
self.mc2_group_patch = mock.patch(
|
||||
self.mc2_group_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
|
||||
return_value=self.mc2_group)
|
||||
self.mc2_group_patch.start()
|
||||
|
||||
self.rank_group_patch = mock.patch("torch.distributed.get_rank",
|
||||
return_value=0)
|
||||
self.rank_group_patch = patch("torch.distributed.get_rank",
|
||||
return_value=0)
|
||||
self.rank_group_patch.start()
|
||||
|
||||
# Mock get_forward_context().mc2_mask
|
||||
self.forward_context = mock.MagicMock()
|
||||
self.forward_context = MagicMock()
|
||||
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
|
||||
self.forward_context_patch = mock.patch(
|
||||
self.forward_context_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
|
||||
return_value=self.forward_context)
|
||||
self.forward_context_patch.start()
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = mock.patch(
|
||||
self.ascend_soc_version_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
# Mock get_ascend_config()
|
||||
self.ascend_config = mock.MagicMock()
|
||||
self.ascend_config = MagicMock()
|
||||
self.ascend_config.torchair_graph_config.enabled = False
|
||||
self.ascend_config_patch = mock.patch(
|
||||
self.ascend_config_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
|
||||
return_value=self.ascend_config)
|
||||
self.ascend_config_patch.start()
|
||||
@@ -127,13 +127,13 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
self.assertTrue(self.dispatcher.need_extra_args)
|
||||
self.assertTrue(self.dispatcher.a3_need_extra_args)
|
||||
|
||||
def test_get_permute_mc2_kwargs_without_quant(self):
|
||||
def test_get_dispatch_mc2_kwargs_without_quant(self):
|
||||
hidden_states = torch.randn(10, 128)
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
kwargs = self.dispatcher.get_permute_mc2_kwargs(
|
||||
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map)
|
||||
self.assertIn("x", kwargs)
|
||||
self.assertIn("expert_ids", kwargs)
|
||||
@@ -145,17 +145,16 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) *
|
||||
5) as mock_dispatch:
|
||||
output = self.dispatcher.token_permutation(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
expert_map)
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
|
||||
output = self.dispatcher.token_dispatch(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output[0], 1) # group_list_type == 1
|
||||
|
||||
def test_token_permutation_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = mock.MagicMock()
|
||||
def test_token_dispatch_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = MagicMock()
|
||||
self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
|
||||
torch.tensor(1.0))
|
||||
self.shared_experts.act_fn.return_value = torch.randn(10, 128)
|
||||
@@ -165,15 +164,15 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
self.topk_weights = torch.randn(10, 1)
|
||||
|
||||
with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5):
|
||||
with mock.patch(
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5):
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
|
||||
autospec=True):
|
||||
with mock.patch(
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
|
||||
autospec=True) as mock_wait:
|
||||
self.dispatcher.token_permutation(
|
||||
self.dispatcher.token_dispatch(
|
||||
self.hidden_states,
|
||||
self.topk_weights,
|
||||
torch.randint(0, 8, (10, 1)),
|
||||
@@ -182,7 +181,7 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
mock_wait.assert_any_call(self.hidden_states,
|
||||
self.topk_weights)
|
||||
|
||||
def test_get_unpermute_mc_kwargs_with_quant(self):
|
||||
def test_get_combine_mc_kwargs_with_quant(self):
|
||||
self.dispatcher.with_quant = True
|
||||
hidden_states = torch.randn(10, 128)
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
@@ -193,11 +192,11 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
|
||||
kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states)
|
||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
|
||||
self.assertIn("tp_send_counts", kwargs)
|
||||
|
||||
def test_token_unpermutation_with_shared_experts(self):
|
||||
self.dispatcher.shared_experts = mock.MagicMock()
|
||||
def test_token_combine_with_shared_experts(self):
|
||||
self.dispatcher.shared_experts = MagicMock()
|
||||
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
|
||||
10, 128), torch.tensor(1.0))
|
||||
self.dispatcher.shared_act = torch.randn(10, 128)
|
||||
@@ -212,18 +211,18 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
|
||||
with mock.patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
with mock.patch(
|
||||
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
|
||||
autospec=True):
|
||||
with mock.patch(
|
||||
with patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
|
||||
autospec=True):
|
||||
self.dispatcher.token_unpermutation(self.hidden_states)
|
||||
self.dispatcher.token_combine(self.hidden_states)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
class TestTokenDispatcherWithAllGather(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Mock dependencies
|
||||
@@ -238,8 +237,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
# Mock NPU functions
|
||||
self.patcher_moe_init_routing = mock.patch(
|
||||
'torch_npu.npu_moe_init_routing')
|
||||
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
|
||||
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
|
||||
self.mock_moe_init_routing.return_value = (
|
||||
torch.randn(6, 128), # sorted_hidden_states
|
||||
@@ -247,14 +245,14 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
|
||||
)
|
||||
|
||||
self.patcher_moe_compute_expert_tokens = mock.patch(
|
||||
self.patcher_moe_compute_expert_tokens = patch(
|
||||
'torch_npu.npu_moe_compute_expert_tokens')
|
||||
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
|
||||
)
|
||||
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
|
||||
[3, 3]) # expert_tokens
|
||||
|
||||
self.patcher_moe_finalize_routing = mock.patch(
|
||||
self.patcher_moe_finalize_routing = patch(
|
||||
'torch_npu.npu_moe_finalize_routing')
|
||||
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
|
||||
)
|
||||
@@ -265,12 +263,12 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
self.patcher_moe_compute_expert_tokens.stop()
|
||||
self.patcher_moe_finalize_routing.stop()
|
||||
|
||||
def test_token_permutation_without_expert_map(self):
|
||||
def test_token_dispatch_without_expert_map(self):
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
@@ -279,7 +277,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
|
||||
self.assertEqual(group_list_type, 0)
|
||||
|
||||
def test_token_permutation_with_quant(self):
|
||||
def test_token_dispatch_with_quant(self):
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
@@ -294,13 +292,13 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation(
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
|
||||
# Verify quant mode returns group_list_type=1
|
||||
self.assertEqual(group_list_type, 0)
|
||||
|
||||
def test_token_unpermutation_with_expert_map(self):
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
@@ -309,13 +307,12 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_unpermutation(
|
||||
hidden_states)
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify index_add_ is applied correctly
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
|
||||
def test_token_unpermutation_without_expert_map(self):
|
||||
def test_token_combine_without_expert_map(self):
|
||||
self.dispatcher.with_quant = False
|
||||
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
@@ -326,8 +323,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_unpermutation(
|
||||
hidden_states)
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify npu_moe_finalize_routing is called
|
||||
self.mock_moe_finalize_routing.assert_called_once()
|
||||
@@ -335,22 +331,231 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
|
||||
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
|
||||
def test_token_permutation_with_router_weight(self):
|
||||
def test_token_dispatch_with_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
|
||||
topk_ids = torch.tensor([[0], [1], [2]])
|
||||
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
|
||||
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
|
||||
hidden_states, topk_weights, topk_ids, None)
|
||||
self.assertEqual(sorted_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_permutation_invalid_topk_when_router_weight(self):
|
||||
def test_token_dispatch_invalid_topk_when_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
self.dispatcher.token_permutation(
|
||||
self.dispatcher.token_dispatch(
|
||||
hidden_states, topk_weights,
|
||||
torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Patch properties
|
||||
patcher1 = patch.object(TokenDispatcherWithAll2AllV,
|
||||
'ep_group',
|
||||
new_callable=PropertyMock,
|
||||
return_value=MagicMock())
|
||||
patcher2 = patch.object(TokenDispatcherWithAll2AllV,
|
||||
'ep_rank',
|
||||
new_callable=PropertyMock,
|
||||
return_value=0)
|
||||
patcher3 = patch.object(TokenDispatcherWithAll2AllV,
|
||||
'ep_size',
|
||||
new_callable=PropertyMock,
|
||||
return_value=2)
|
||||
|
||||
self.addCleanup(patcher1.stop)
|
||||
self.addCleanup(patcher2.stop)
|
||||
self.addCleanup(patcher3.stop)
|
||||
|
||||
self.mock_ep_group_prop = patcher1.start()
|
||||
self.mock_ep_rank_prop = patcher2.start()
|
||||
self.mock_ep_size_prop = patcher3.start()
|
||||
|
||||
# Mock torch_npu.npu_moe_token_permute
|
||||
patcher4 = patch('torch_npu.npu_moe_token_permute')
|
||||
self.mock_npu_moe_token_permute = patcher4.start()
|
||||
self.addCleanup(patcher4.stop)
|
||||
self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
|
||||
torch.arange(16))
|
||||
|
||||
# Mock torch_npu.npu_moe_token_unpermute
|
||||
patcher5 = patch('torch_npu.npu_moe_token_unpermute')
|
||||
self.mock_npu_moe_token_unpermute = patcher5.start()
|
||||
self.addCleanup(patcher5.stop)
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
|
||||
|
||||
# Mock async_all_to_all
|
||||
patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
|
||||
self.mock_async_all_to_all = patcher6.start()
|
||||
self.addCleanup(patcher6.stop)
|
||||
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
|
||||
MagicMock())
|
||||
|
||||
# Mock gather_from_sequence_parallel_region
|
||||
patcher7 = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
)
|
||||
self.mock_gather_from_sequence_parallel_region = patcher7.start()
|
||||
self.addCleanup(patcher7.stop)
|
||||
self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
|
||||
[[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)
|
||||
|
||||
# Mock torch.histc
|
||||
patcher8 = patch('torch.histc')
|
||||
self.mock_histc = patcher8.start()
|
||||
self.addCleanup(patcher8.stop)
|
||||
self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
|
||||
dtype=torch.int64)
|
||||
|
||||
# Mock torch.npu.current_device
|
||||
patcher9 = patch('torch.npu.current_device')
|
||||
self.mock_current_device = patcher9.start()
|
||||
self.addCleanup(patcher9.stop)
|
||||
self.mock_current_device.return_value = 'cpu'
|
||||
|
||||
# Mock torch_npu.npu_dynamic_quant
|
||||
patcher10 = patch('torch_npu.npu_dynamic_quant')
|
||||
self.mock_npu_dynamic_quant = patcher10.start()
|
||||
self.addCleanup(patcher10.stop)
|
||||
self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
|
||||
torch.randn(16))
|
||||
|
||||
# Mock torch_npu.npu_moe_init_routing_v2
|
||||
patcher11 = patch('torch_npu.npu_moe_init_routing_v2')
|
||||
self.mock_npu_moe_init_routing_v2 = patcher11.start()
|
||||
self.addCleanup(patcher11.stop)
|
||||
self.mock_npu_moe_init_routing_v2.return_value = (torch.randn(
|
||||
16, 16), torch.arange(16), None, torch.randn(16))
|
||||
|
||||
# Mock torch.repeat_interleave
|
||||
patcher12 = patch('torch.repeat_interleave')
|
||||
self.mock_repeat_interleave = patcher12.start()
|
||||
self.addCleanup(patcher12.stop)
|
||||
self.mock_repeat_interleave.return_value = torch.arange(16)
|
||||
|
||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
with_quant=False)
|
||||
|
||||
def test_token_dispatch(self):
|
||||
hidden_states = torch.randn(8, 16)
|
||||
topk_weights = torch.rand(8, 4)
|
||||
topk_ids = torch.randint(0, 4, (8, 2)).long()
|
||||
expert_map = torch.tensor([0, 1, 2, 3])
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
def test_token_combine(self):
|
||||
self.dispatcher.hidden_shape = (8, 16)
|
||||
self.dispatcher.hidden_shape_before_permute = (8, 16)
|
||||
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
|
||||
8)
|
||||
self.dispatcher.topk_weights = torch.rand(8, 4)
|
||||
self.dispatcher.input_splits = [4, 4]
|
||||
self.dispatcher.output_splits = [4, 4]
|
||||
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
|
||||
16)
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
|
||||
[[2, 2], [2, 2]], dtype=torch.int64)
|
||||
|
||||
expert_output = torch.randn(16, 16)
|
||||
output = self.dispatcher.token_combine(expert_output)
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
self.assertEqual(output.shape, (8, 16))
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
with_quant=True)
|
||||
|
||||
hidden_states = torch.randn(8, 16)
|
||||
topk_weights = torch.rand(8, 4)
|
||||
topk_ids = torch.randint(0, 4, (8, 2)).long()
|
||||
expert_map = torch.tensor([0, 1, 2, 3])
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertIsNotNone(result["dynamic_scale"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant_no_active_tokens(self):
|
||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
with_quant=True)
|
||||
|
||||
self.mock_repeat_interleave.return_value = torch.tensor(
|
||||
[], dtype=torch.long)
|
||||
|
||||
hidden_states = torch.randn(8, 16)
|
||||
topk_weights = torch.rand(8, 4)
|
||||
topk_ids = torch.randint(0, 4, (8, 2)).long()
|
||||
expert_map = torch.tensor([0, 1, 2, 3])
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertIsNotNone(result["dynamic_scale"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_log2phy(self):
|
||||
hidden_states = torch.randn(8, 16)
|
||||
topk_weights = torch.rand(8, 4)
|
||||
topk_ids = torch.randint(0, 4, (8, 2)).long()
|
||||
expert_map = torch.tensor([0, 1, 2, 3])
|
||||
log2phy = torch.tensor([1, 0, 3, 2])
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
Reference in New Issue
Block a user