[main] refactor alltoallv in fused_moe (#2487)

### What this PR does / why we need it?
Refactor the all2all-related fused_experts paths (both quantized and unquantized) into TokenDispatcherWithAll2AllV, covering both the dispatch and combine computation.
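At a glance, the renamed dispatcher interface (an illustrative call site, not code from this PR):

```python
# All token dispatchers now expose token_dispatch/token_combine in place of
# the old token_permutation/token_unpermutation pair.
out = dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids, expert_map)
# ... grouped expert computation on the dispatched tokens ...
final_hidden_states = dispatcher.token_combine(expert_output)
```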
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E tests & unit tests
- vLLM version: v0.10.0
- vLLM main: 65197a5fb3

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Author: weichen
Date: 2025-08-23 20:38:17 +08:00 (committed by GitHub)
Parent: 4af5b80606
Commit: 950c4b219a
2 changed files with 560 additions and 107 deletions


@@ -15,17 +15,17 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.

-import unittest
-from unittest import mock
+from unittest.mock import MagicMock, PropertyMock, patch

 import pytest
 import torch
 from pytest_mock import MockerFixture

-from tests.ut.base import PytestBase
+from tests.ut.base import PytestBase, TestBase
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
-    TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
+    TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
+    TokenDispatcherWithMC2)
 from vllm_ascend.utils import adapt_patch  # noqa E402
@@ -70,40 +70,40 @@ class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
         assert dispatcher.overlap_stream is not None


-class TestTokenDispatcherWithMC2(unittest.TestCase):
+class TestTokenDispatcherWithMC2(TestBase):

     def setUp(self):
-        self.mc2_group = mock.MagicMock()
+        self.mc2_group = MagicMock()
         self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
         self.mc2_group.rank_in_group = 0
         self.mc2_group.world_size = 8
-        self.mc2_group_patch = mock.patch(
+        self.mc2_group_patch = patch(
             "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
             return_value=self.mc2_group)
         self.mc2_group_patch.start()

-        self.rank_group_patch = mock.patch("torch.distributed.get_rank",
-                                           return_value=0)
+        self.rank_group_patch = patch("torch.distributed.get_rank",
+                                      return_value=0)
         self.rank_group_patch.start()

         # Mock get_forward_context().mc2_mask
-        self.forward_context = mock.MagicMock()
+        self.forward_context = MagicMock()
         self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
-        self.forward_context_patch = mock.patch(
+        self.forward_context_patch = patch(
             "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context",
             return_value=self.forward_context)
         self.forward_context_patch.start()

         # Mock get_ascend_soc_version()
-        self.ascend_soc_version_patch = mock.patch(
+        self.ascend_soc_version_patch = patch(
             "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
             return_value=AscendSocVersion.A3)
         self.ascend_soc_version_patch.start()

         # Mock get_ascend_config()
-        self.ascend_config = mock.MagicMock()
+        self.ascend_config = MagicMock()
         self.ascend_config.torchair_graph_config.enabled = False
-        self.ascend_config_patch = mock.patch(
+        self.ascend_config_patch = patch(
             "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config",
             return_value=self.ascend_config)
         self.ascend_config_patch.start()
@@ -127,13 +127,13 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
         self.assertTrue(self.dispatcher.need_extra_args)
         self.assertTrue(self.dispatcher.a3_need_extra_args)

-    def test_get_permute_mc2_kwargs_without_quant(self):
+    def test_get_dispatch_mc2_kwargs_without_quant(self):
         hidden_states = torch.randn(10, 128)
         topk_ids = torch.randint(0, 8, (10, 1))
         topk_weights = torch.randn(10, 1)
         expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

-        kwargs = self.dispatcher.get_permute_mc2_kwargs(
+        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
             hidden_states, topk_weights, topk_ids, expert_map)
         self.assertIn("x", kwargs)
         self.assertIn("expert_ids", kwargs)
@@ -145,17 +145,16 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
         topk_ids = torch.randint(0, 8, (10, 1))
         expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

-        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
-                        return_value=(torch.randn(10, 128), ) *
-                        5) as mock_dispatch:
-            output = self.dispatcher.token_permutation(hidden_states,
-                                                       topk_weights, topk_ids,
-                                                       expert_map)
+        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
+                   return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
+            output = self.dispatcher.token_dispatch(hidden_states,
+                                                    topk_weights, topk_ids,
+                                                    expert_map)
             mock_dispatch.assert_called_once()
             self.assertEqual(output[0], 1)  # group_list_type == 1

-    def test_token_permutation_with_shared_experts_and_quant(self):
-        self.shared_experts = mock.MagicMock()
+    def test_token_dispatch_with_shared_experts_and_quant(self):
+        self.shared_experts = MagicMock()
         self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                          torch.tensor(1.0))
         self.shared_experts.act_fn.return_value = torch.randn(10, 128)
@@ -165,15 +164,15 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
         self.hidden_states = torch.randn(10, 128)
         self.topk_weights = torch.randn(10, 1)

-        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2",
-                        return_value=(torch.randn(10, 128), ) * 5):
-            with mock.patch(
+        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
+                   return_value=(torch.randn(10, 128), ) * 5):
+            with patch(
                     "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
                     autospec=True):
-                with mock.patch(
+                with patch(
                         "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
                         autospec=True) as mock_wait:
-                    self.dispatcher.token_permutation(
+                    self.dispatcher.token_dispatch(
                         self.hidden_states,
                         self.topk_weights,
                         torch.randint(0, 8, (10, 1)),
@@ -182,7 +181,7 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
                     mock_wait.assert_any_call(self.hidden_states,
                                               self.topk_weights)

-    def test_get_unpermute_mc_kwargs_with_quant(self):
+    def test_get_combine_mc_kwargs_with_quant(self):
         self.dispatcher.with_quant = True
         hidden_states = torch.randn(10, 128)
         self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
@@ -193,11 +192,11 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
         self.dispatcher.enable_dispatch_v2 = True
         self.dispatcher.output = torch.randint(0, 8, (10, 1))

-        kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states)
+        kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
         self.assertIn("tp_send_counts", kwargs)

-    def test_token_unpermutation_with_shared_experts(self):
-        self.dispatcher.shared_experts = mock.MagicMock()
+    def test_token_combine_with_shared_experts(self):
+        self.dispatcher.shared_experts = MagicMock()
         self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
             10, 128), torch.tensor(1.0))
         self.dispatcher.shared_act = torch.randn(10, 128)
@@ -212,18 +211,18 @@ class TestTokenDispatcherWithMC2(unittest.TestCase):
         self.dispatcher.output = torch.randint(0, 8, (10, 1))

         self.hidden_states = torch.randn(10, 128)
-        with mock.patch("torch_npu.npu_moe_distribute_combine_v2",
-                        return_value=torch.randn(10, 128)):
-            with mock.patch(
+        with patch("torch_npu.npu_moe_distribute_combine_v2",
+                   return_value=torch.randn(10, 128)):
+            with patch(
                     "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch",
                     autospec=True):
-                with mock.patch(
+                with patch(
                         "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor",
                         autospec=True):
-                    self.dispatcher.token_unpermutation(self.hidden_states)
+                    self.dispatcher.token_combine(self.hidden_states)


-class TestTokenDispatcherWithAllGather(unittest.TestCase):
+class TestTokenDispatcherWithAllGather(TestBase):

     def setUp(self):
         # Mock dependencies
@@ -238,8 +237,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
         self.dispatcher = TokenDispatcherWithAllGather(**kwargs)

         # Mock NPU functions
-        self.patcher_moe_init_routing = mock.patch(
-            'torch_npu.npu_moe_init_routing')
+        self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
         self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
         self.mock_moe_init_routing.return_value = (
             torch.randn(6, 128),  # sorted_hidden_states
@@ -247,14 +245,14 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
             torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx
         )

-        self.patcher_moe_compute_expert_tokens = mock.patch(
+        self.patcher_moe_compute_expert_tokens = patch(
             'torch_npu.npu_moe_compute_expert_tokens')
         self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
         )
         self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
             [3, 3])  # expert_tokens

-        self.patcher_moe_finalize_routing = mock.patch(
+        self.patcher_moe_finalize_routing = patch(
             'torch_npu.npu_moe_finalize_routing')
         self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
         )
@@ -265,12 +263,12 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
         self.patcher_moe_compute_expert_tokens.stop()
         self.patcher_moe_finalize_routing.stop()

-    def test_token_permutation_without_expert_map(self):
+    def test_token_dispatch_without_expert_map(self):
         hidden_states = torch.randn(3, 128)
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

-        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
             hidden_states, topk_weights, topk_ids, None)

         # Verify npu_moe_init_routing is called
@@ -279,7 +277,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
         self.assertEqual(group_list_type, 0)

-    def test_token_permutation_with_quant(self):
+    def test_token_dispatch_with_quant(self):
         kwargs = {
             "apply_router_weight_on_input": False,
             "top_k": 2,
@@ -294,13 +292,13 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

-        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation(
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch(
             hidden_states, topk_weights, topk_ids, None)

         # Verify quant mode returns group_list_type=0
         self.assertEqual(group_list_type, 0)

-    def test_token_unpermutation_with_expert_map(self):
+    def test_token_combine_with_expert_map(self):
         self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
         self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
         self.dispatcher.sorted_weights = torch.tensor(
@@ -309,13 +307,12 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
         self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
         hidden_states = torch.randn(6, 128)

-        final_hidden_states = self.dispatcher.token_unpermutation(
-            hidden_states)
+        final_hidden_states = self.dispatcher.token_combine(hidden_states)

         # Verify index_add_ is applied correctly
         self.assertEqual(final_hidden_states.shape, (3, 128))

-    def test_token_unpermutation_without_expert_map(self):
+    def test_token_combine_without_expert_map(self):
         self.dispatcher.with_quant = False
         self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
         self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
@@ -326,8 +323,7 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
         self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
         hidden_states = torch.randn(6, 128)

-        final_hidden_states = self.dispatcher.token_unpermutation(
-            hidden_states)
+        final_hidden_states = self.dispatcher.token_combine(hidden_states)

         # Verify npu_moe_finalize_routing is called
         self.mock_moe_finalize_routing.assert_called_once()
@@ -335,22 +331,231 @@ class TestTokenDispatcherWithAllGather(unittest.TestCase):
         self.assertEqual(final_hidden_states.shape, (3, 128))

-    def test_token_permutation_with_router_weight(self):
+    def test_token_dispatch_with_router_weight(self):
         self.dispatcher.apply_router_weight_on_input = True
         hidden_states = torch.randn(3, 128)
         topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
         topk_ids = torch.tensor([[0], [1], [2]])

-        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation(
+        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch(
             hidden_states, topk_weights, topk_ids, None)
         self.assertEqual(sorted_hidden_states.shape, (6, 128))

-    def test_token_permutation_invalid_topk_when_router_weight(self):
+    def test_token_dispatch_invalid_topk_when_router_weight(self):
         self.dispatcher.apply_router_weight_on_input = True
         hidden_states = torch.randn(3, 128)
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])

         with self.assertRaises(AssertionError):
-            self.dispatcher.token_permutation(
+            self.dispatcher.token_dispatch(
                 hidden_states, topk_weights,
                 torch.tensor([[0, 1], [1, 2], [2, 3]]), None)
+
+
+class TestTokenDispatcherWithAll2AllV(TestBase):
+
+    def setUp(self):
+        # Patch properties
+        patcher1 = patch.object(TokenDispatcherWithAll2AllV,
+                                'ep_group',
+                                new_callable=PropertyMock,
+                                return_value=MagicMock())
+        patcher2 = patch.object(TokenDispatcherWithAll2AllV,
+                                'ep_rank',
+                                new_callable=PropertyMock,
+                                return_value=0)
+        patcher3 = patch.object(TokenDispatcherWithAll2AllV,
+                                'ep_size',
+                                new_callable=PropertyMock,
+                                return_value=2)
+        self.addCleanup(patcher1.stop)
+        self.addCleanup(patcher2.stop)
+        self.addCleanup(patcher3.stop)
+        self.mock_ep_group_prop = patcher1.start()
+        self.mock_ep_rank_prop = patcher2.start()
+        self.mock_ep_size_prop = patcher3.start()
+
+        # Mock torch_npu.npu_moe_token_permute
+        patcher4 = patch('torch_npu.npu_moe_token_permute')
+        self.mock_npu_moe_token_permute = patcher4.start()
+        self.addCleanup(patcher4.stop)
+        self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
+                                                        torch.arange(16))
+
+        # Mock torch_npu.npu_moe_token_unpermute
+        patcher5 = patch('torch_npu.npu_moe_token_unpermute')
+        self.mock_npu_moe_token_unpermute = patcher5.start()
+        self.addCleanup(patcher5.stop)
+        self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
+
+        # Mock async_all_to_all
+        patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
+        self.mock_async_all_to_all = patcher6.start()
+        self.addCleanup(patcher6.stop)
+        self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
+                                                   MagicMock())
+
+        # Mock gather_from_sequence_parallel_region
+        patcher7 = patch(
+            'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
+        )
+        self.mock_gather_from_sequence_parallel_region = patcher7.start()
+        self.addCleanup(patcher7.stop)
+        self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
+            [[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)
+
+        # Mock torch.histc
+        patcher8 = patch('torch.histc')
+        self.mock_histc = patcher8.start()
+        self.addCleanup(patcher8.stop)
+        self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
+                                                    dtype=torch.int64)
+
+        # Mock torch.npu.current_device
+        patcher9 = patch('torch.npu.current_device')
+        self.mock_current_device = patcher9.start()
+        self.addCleanup(patcher9.stop)
+        self.mock_current_device.return_value = 'cpu'
+
+        # Mock torch_npu.npu_dynamic_quant
+        patcher10 = patch('torch_npu.npu_dynamic_quant')
+        self.mock_npu_dynamic_quant = patcher10.start()
+        self.addCleanup(patcher10.stop)
+        self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
+                                                    torch.randn(16))
+
+        # Mock torch_npu.npu_moe_init_routing_v2
+        patcher11 = patch('torch_npu.npu_moe_init_routing_v2')
+        self.mock_npu_moe_init_routing_v2 = patcher11.start()
+        self.addCleanup(patcher11.stop)
+        self.mock_npu_moe_init_routing_v2.return_value = (torch.randn(
+            16, 16), torch.arange(16), None, torch.randn(16))
+
+        # Mock torch.repeat_interleave
+        patcher12 = patch('torch.repeat_interleave')
+        self.mock_repeat_interleave = patcher12.start()
+        self.addCleanup(patcher12.stop)
+        self.mock_repeat_interleave.return_value = torch.arange(16)
+
+        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
+                                                      num_experts=4,
+                                                      num_local_experts=2,
+                                                      with_quant=False)
+
+    def test_token_dispatch(self):
+        hidden_states = torch.randn(8, 16)
+        topk_weights = torch.rand(8, 4)
+        topk_ids = torch.randint(0, 4, (8, 2)).long()
+        expert_map = torch.tensor([0, 1, 2, 3])
+
+        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
+            [0, 1], dtype=torch.int32)
+        self.dispatcher.local_expert_indices = [0, 1]
+
+        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
+                                                topk_weights=topk_weights,
+                                                topk_ids=topk_ids,
+                                                expert_map=expert_map)
+
+        self.assertIsNotNone(result["hidden_states"])
+        self.assertIsNotNone(result["group_list"])
+        self.assertEqual(result["group_list_type"], 1)
+
+    def test_token_combine(self):
+        self.dispatcher.hidden_shape = (8, 16)
+        self.dispatcher.hidden_shape_before_permute = (8, 16)
+        self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
+            8)
+        self.dispatcher.topk_weights = torch.rand(8, 4)
+        self.dispatcher.input_splits = [4, 4]
+        self.dispatcher.output_splits = [4, 4]
+        self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
+            16)
+        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
+            [0, 1], dtype=torch.int32)
+        self.dispatcher.local_expert_indices = [0, 1]
+        self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
+            [[2, 2], [2, 2]], dtype=torch.int64)
+
+        expert_output = torch.randn(16, 16)
+        output = self.dispatcher.token_combine(expert_output)
+        self.assertIsNotNone(output)
+        self.assertEqual(output.shape, (8, 16))
+
+    def test_token_dispatch_with_quant(self):
+        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
+                                                      num_experts=4,
+                                                      num_local_experts=2,
+                                                      with_quant=True)
+
+        hidden_states = torch.randn(8, 16)
+        topk_weights = torch.rand(8, 4)
+        topk_ids = torch.randint(0, 4, (8, 2)).long()
+        expert_map = torch.tensor([0, 1, 2, 3])
+
+        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
+            [0, 1], dtype=torch.int32)
+        self.dispatcher.local_expert_indices = [0, 1]
+
+        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
+                                                topk_weights=topk_weights,
+                                                topk_ids=topk_ids,
+                                                expert_map=expert_map)
+
+        self.assertIsNotNone(result["hidden_states"])
+        self.assertIsNotNone(result["group_list"])
+        self.assertIsNotNone(result["dynamic_scale"])
+        self.assertEqual(result["group_list_type"], 1)
+
+    def test_token_dispatch_with_quant_no_active_tokens(self):
+        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
+                                                      num_experts=4,
+                                                      num_local_experts=2,
+                                                      with_quant=True)
+
+        self.mock_repeat_interleave.return_value = torch.tensor(
+            [], dtype=torch.long)
+
+        hidden_states = torch.randn(8, 16)
+        topk_weights = torch.rand(8, 4)
+        topk_ids = torch.randint(0, 4, (8, 2)).long()
+        expert_map = torch.tensor([0, 1, 2, 3])
+
+        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
+            [0, 1], dtype=torch.int32)
+        self.dispatcher.local_expert_indices = [0, 1]
+
+        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
+                                                topk_weights=topk_weights,
+                                                topk_ids=topk_ids,
+                                                expert_map=expert_map)
+
+        self.assertIsNotNone(result["hidden_states"])
+        self.assertIsNotNone(result["group_list"])
+        self.assertIsNotNone(result["dynamic_scale"])
+        self.assertEqual(result["group_list_type"], 1)
+
+    def test_token_dispatch_with_log2phy(self):
+        hidden_states = torch.randn(8, 16)
+        topk_weights = torch.rand(8, 4)
+        topk_ids = torch.randint(0, 4, (8, 2)).long()
+        expert_map = torch.tensor([0, 1, 2, 3])
+        log2phy = torch.tensor([1, 0, 3, 2])
+
+        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
+            [0, 1], dtype=torch.int32)
+        self.dispatcher.local_expert_indices = [0, 1]
+
+        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
+                                                topk_weights=topk_weights,
+                                                topk_ids=topk_ids,
+                                                expert_map=expert_map,
+                                                log2phy=log2phy)
+
+        self.assertIsNotNone(result["hidden_states"])
+        self.assertIsNotNone(result["group_list"])
+        self.assertEqual(result["group_list_type"], 1)


@@ -22,7 +22,7 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Optional

 import torch
 import torch_npu
@@ -466,8 +466,9 @@ class MoETokenDispatcher(ABC):
"""
Initialize the MoE Token Dispatcher.
"""
self.top_k = kwargs.get("top_k")
self.num_experts = kwargs.get("num_experts")
self.top_k = kwargs.get("top_k", 0)
self.num_experts = kwargs.get("num_experts", 0)
self.with_quant = kwargs.get("with_quant", False)
@property
def ep_group(self):
@@ -483,25 +484,25 @@ class MoETokenDispatcher(ABC):
         return get_ep_group().world_size

     @abstractmethod
-    def token_permutation(
+    def token_dispatch(
             self,
             hidden_states: torch.Tensor,
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             expert_map: torch.Tensor,
-            log2phy: torch.Tensor = None,
+            log2phy: Optional[torch.Tensor] = None,
             global_redundant_expert_num: int = 0,
-            shared_gate_up: Optional[Any] = None,
-            shared_dequant_scale: Optional[Any] = None,
-            shared_experts: Optional[Any] = None,
+            shared_gate_up: Optional[torch.Tensor] = None,
+            shared_dequant_scale: Optional[torch.Tensor] = None,
+            shared_experts: Optional[torch.Tensor] = None,
     ):
         raise NotImplementedError("Dispatch function not implemented.")

     @abstractmethod
-    def token_unpermutation(self,
-                            hidden_states: torch.Tensor,
-                            bias: torch.Tensor = None):
-        raise NotImplementedError("Restore function not implemented.")
+    def token_combine(self,
+                      hidden_states: torch.Tensor,
+                      bias: torch.Tensor = None):
+        raise NotImplementedError("Combine function not implemented.")


 class TokenDispatcherWithMC2(MoETokenDispatcher):
@@ -517,7 +518,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         self.ep_world_size = get_mc2_group().world_size
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.with_quant = kwargs.get("with_quant")
         self.enable_dispatch_v2 = hasattr(torch_npu,
                                           "npu_moe_distribute_dispatch_v2")
         self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
@@ -535,12 +535,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         self.topk_weights = None
         self.shared_experts = None

-    def get_permute_mc2_kwargs(self,
-                               hidden_states: torch.Tensor,
-                               topk_weights: torch.Tensor,
-                               topk_ids: torch.Tensor,
-                               expert_map: torch.Tensor,
-                               global_redundant_expert_num: int = 0):
+    def get_dispatch_mc2_kwargs(self,
+                                hidden_states: torch.Tensor,
+                                topk_weights: torch.Tensor,
+                                topk_ids: torch.Tensor,
+                                expert_map: torch.Tensor,
+                                global_redundant_expert_num: int = 0):
         quant_mode = 0
         forward_context = get_forward_context()
         mc2_mask = forward_context.mc2_mask
@@ -581,26 +581,26 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         kwargs_mc2.update(stage1_kwargs)
         return kwargs_mc2

-    def token_permutation(
+    def token_dispatch(
             self,
             hidden_states: torch.Tensor,
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             expert_map: torch.Tensor,
-            log2phy: torch.Tensor = None,
+            log2phy: Optional[torch.Tensor] = None,
             global_redundant_expert_num: int = 0,
-            shared_gate_up: Optional[Any] = None,
-            shared_dequant_scale: Optional[Any] = None,
-            shared_experts: Optional[Any] = None,
+            shared_gate_up: Optional[torch.Tensor] = None,
+            shared_dequant_scale: Optional[torch.Tensor] = None,
+            shared_experts: Optional[torch.Tensor] = None,
     ):
         self.expert_map = expert_map
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights
         self.shared_experts = shared_experts

-        kwargs_mc2 = self.get_permute_mc2_kwargs(hidden_states, topk_weights,
-                                                 topk_ids, expert_map,
-                                                 global_redundant_expert_num)
+        kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
+                                                  topk_ids, expert_map,
+                                                  global_redundant_expert_num)

         self.output = torch_npu.npu_moe_distribute_dispatch_v2(
             **kwargs_mc2
         ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
@@ -629,7 +629,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         group_list_type = 1
         return group_list_type, expand_x, expert_token_nums

-    def get_unpermute_mc_kwargs(self, hidden_states: torch.Tensor):
+    def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
         assert self.expert_map is not None
         assert self.topk_weights is not None
         assert self.topk_ids is not None
@@ -682,11 +682,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         kwargs_mc2.update(stage3_kwargs)
         return kwargs_mc2

-    def token_unpermutation(self,
-                            hidden_states: torch.Tensor,
-                            bias: torch.Tensor = None):
-        kwargs_mc2 = self.get_unpermute_mc_kwargs(hidden_states)
+    def token_combine(self,
+                      hidden_states: torch.Tensor,
+                      bias: torch.Tensor = None):
+        kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)

         hidden_states = torch_npu.npu_moe_distribute_combine_v2(
             **kwargs_mc2
         ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
@@ -718,7 +718,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
ep_size = kwargs.get("ep_size")
if ep_size is not None:
self.num_experts_local = self.num_experts // ep_size
self.with_quant = kwargs.get("with_quant")
self.sorted_weights = None
self.expanded_row_idx = None
self.sorted_token_indices = None
@@ -728,17 +727,17 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         self.topk_weights = None
         self.topk_ids = None

-    def token_permutation(
+    def token_dispatch(
             self,
             hidden_states: torch.Tensor,
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             expert_map: torch.Tensor,
-            log2phy: torch.Tensor = None,
+            log2phy: Optional[torch.Tensor] = None,
             global_redundant_expert_num: int = 0,
-            shared_gate_up: Optional[Any] = None,
-            shared_dequant_scale: Optional[Any] = None,
-            shared_experts: Optional[Any] = None,
+            shared_gate_up: Optional[torch.Tensor] = None,
+            shared_dequant_scale: Optional[torch.Tensor] = None,
+            shared_experts: Optional[torch.Tensor] = None,
     ):
         self.original_shape = hidden_states.shape
         # assert len(original_shape) == 2
@@ -830,9 +829,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         group_list_type = 0
         return group_list_type, sorted_hidden_states, expert_tokens

-    def token_unpermutation(self,
-                            hidden_states: torch.Tensor,
-                            bias: torch.Tensor = None):
+    def token_combine(self,
+                      hidden_states: torch.Tensor,
+                      bias: torch.Tensor = None):
         assert self.mask is not None
         assert self.sorted_token_indices is not None
         assert self.sorted_weights is not None
@@ -901,23 +900,22 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
"apply_router_weight_on_input")
ep_size = kwargs.get("ep_size")
self.local_ep = ep_size
self.top_k = kwargs.get("top_k")
assert self.local_ep is not None
self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep
self.bsz = None
def token_permutation(
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
log2phy: torch.Tensor = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
shared_experts: Optional[Any] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
):
if self.apply_router_weight_on_input:
@@ -949,9 +947,9 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
         group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
         return hidden_states, group_list

-    def token_unpermutation(self,
-                            hidden_states: torch.Tensor,
-                            bias: torch.Tensor = None):
+    def token_combine(self,
+                      hidden_states: torch.Tensor,
+                      bias: torch.Tensor = None):
         assert self.local_ep is not None
         unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
             torch.int32)
@@ -960,3 +958,253 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
         final_hidden_states = unsorted_hidden_states.reshape(
             self.bsz, self.top_k // self.local_ep, -1).sum(1)
         return final_hidden_states
+
+
+class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
+    """
+    AlltoAll-based token dispatcher that works at the sequence level rather
+    than the token level: each device dispatches the entire sequence, with
+    the hidden state partitioned across devices.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.num_local_experts = kwargs.get("num_local_experts", 0)
+        self.num_global_redundant_experts = kwargs.get(
+            "num_global_redundant_experts", 0)
+        self.num_experts = self.num_experts + self.num_global_redundant_experts
+
+        self.hidden_shape = None
+        self.topk_weights = None
+        self.input_splits = None
+        self.output_splits = None
+        self.hidden_shape_before_permute = None
+
+        # [tp_ep_size * ep_size, num_local_experts]. Represents the number of
+        # tokens sent to each local expert by all ranks.
+        self.num_global_tokens_per_local_expert = None
+
+        # Cached intermediate tensors.
+        self.tokens_per_expert = None
+        self.global_input_tokens_local_experts_indices = None
+
+        assert self.num_local_experts > 0, "Expected at least one expert"
+        if self.num_local_experts > 1:
+            self.expert_ids_per_ep_rank = torch.tensor(
+                [i % self.num_local_experts for i in range(self.num_experts)],
+                dtype=torch.int32,
+                device=torch.npu.current_device(),
+            )
+
+        local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
+        self.local_expert_indices = [
+            local_expert_indices_offset + i
+            for i in range(self.num_local_experts)
+        ]
+        assert (len(self.local_expert_indices) == self.num_local_experts
+                ), "Invalid local expert indices"
+        for i in range(len(self.local_expert_indices) - 1):
+            assert (self.local_expert_indices[i] ==
+                    self.local_expert_indices[i + 1] -
+                    1), "local_expert_indices must be continuous"
+
+    def token_dispatch(
+            self,
+            hidden_states: torch.Tensor,
+            topk_weights: torch.Tensor,
+            topk_ids: torch.Tensor,
+            expert_map: torch.Tensor,
+            log2phy: Optional[torch.Tensor] = None,
+            global_redundant_expert_num: int = 0,
+            shared_gate_up: Optional[torch.Tensor] = None,
+            shared_dequant_scale: Optional[torch.Tensor] = None,
+            shared_experts: Optional[torch.Tensor] = None,
+    ):
+        self.hidden_shape = hidden_states.shape
+        self.topk_weights = topk_weights
+        assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
+        assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"
+
+        if log2phy is not None:
+            topk_ids = log2phy[topk_ids]
+
+        permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
+            hidden_states, topk_ids)
+        self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
+
+        dynamic_scale_after_all2all = None
+        if self.with_quant:
+            permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
+                permutated_local_input_tokens)
+
+            _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
+                dynamic_scale,
+                self.output_splits,
+                self.input_splits,
+                self.ep_group,
+            )
+            permute2_ep_all_to_all_handle.wait()
+            dynamic_scale.untyped_storage().resize_(0)
+
+        _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
+            permutated_local_input_tokens,
+            self.output_splits,
+            self.input_splits,
+            self.ep_group,
+        )
+        permute1_ep_all_to_all_handle.wait()
+        permutated_local_input_tokens.untyped_storage().resize_(0)
+
+        global_input_tokens, dynamic_scale = self._dispatch_postprocess(
+            global_input_tokens, dynamic_scale_after_all2all)
+
+        return {
+            "hidden_states": global_input_tokens,
+            "group_list": tokens_per_expert,
+            "dynamic_scale": dynamic_scale,
+            "group_list_type": 1
+        }
+
+    def token_combine(self,
+                      hidden_states: torch.Tensor,
+                      bias: torch.Tensor = None):
+        assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher"
+
+        hidden_states = self._combine_preprocess(hidden_states)
+
+        # Perform expert parallel AlltoAll communication
+        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+        _, permutated_local_input_tokens, handle = async_all_to_all(
+            hidden_states, self.input_splits, self.output_splits,
+            self.ep_group)
+        handle.wait()
+        hidden_states.untyped_storage().resize_(0)
+
+        output = self._combine_postprocess(permutated_local_input_tokens)
+
+        self.input_splits = None
+        self.output_splits = None
+        self.num_global_tokens_per_local_expert = None
+
+        return output
+
+    def _dispatch_preprocess(self, hidden_states, topk_ids):
+        assert self.hidden_shape is not None
+        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+        tokens_per_expert = self._preprocess(topk_ids)
+
+        self.hidden_shape_before_permute = hidden_states.shape
+
+        permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
+            tokens=hidden_states,
+            indices=topk_ids,
+            num_out_tokens=self.num_out_tokens,
+        )
+        return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
+
+    def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
+        num_local_tokens_per_expert = torch.histc(topk_ids,
+                                                  bins=self.num_experts,
+                                                  min=0,
+                                                  max=self.num_experts)
+
+        ep_size = self.ep_size
+
+        # Dropless
+        self.num_out_tokens = topk_ids.numel()
+
+        # ===================================================
+        # Calculate input_splits, output_splits for alltoall-v.
+        # ===================================================
+        self.input_splits = (num_local_tokens_per_expert.reshape(
+            ep_size,
+            self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
+                                                   non_blocking=True).numpy())
+        num_global_tokens_per_expert = gather_from_sequence_parallel_region(
+            num_local_tokens_per_expert,
+            group=self.ep_group).reshape(ep_size, self.num_experts)
+        self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
+            0]:self.local_expert_indices[-1] + 1]
+        if self.num_global_tokens_per_local_expert is None:
+            raise ValueError(
+                "num_global_tokens_per_local_expert must be set before sum.")
+        self.output_splits = (self.num_global_tokens_per_local_expert.sum(
+            axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
+        num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
+            axis=0)
+        # ===================================================
+        # num_global_tokens_per_expert: [ep_size, num_experts]
+        # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
+        # num_tokens_per_local_expert: [num_local_experts]
+        # ===================================================
+
+        if self.num_local_experts > 1:
+            if self.num_global_tokens_per_local_expert is None:
+                raise ValueError(
+                    "num_global_tokens_per_local_expert must be set before operations."
+                )
+            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+                self.expert_ids_per_ep_rank,
+                self.num_global_tokens_per_local_expert.ravel())
+
+        return num_tokens_per_local_expert
+
+    def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
+        # Early return if no local experts or no tokens
+        if self.num_local_experts <= 1:
+            return global_input_tokens, None
+
+        # Handle quantized case
+        if self.with_quant:
+            assert self.global_input_tokens_local_experts_indices is not None, \
+                "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
+            expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
+                -1)
+            active_num = self.global_input_tokens_local_experts_indices.numel()
+
+            # Handle case with no active tokens
+            if active_num <= 0:
+                self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
+                return global_input_tokens, dynamic_scale
+
+            # Process with active tokens
+            global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
+                global_input_tokens,
+                expert_idx_2d,
+                scale=dynamic_scale,
+                active_num=active_num,
+                expert_capacity=0,
+                expert_num=self.num_local_experts,
+                expert_tokens_num_type=1,
+                expert_tokens_num_flag=True,
+                active_expert_range=[0, self.num_local_experts],
+                quant_mode=-1,
+                row_idx_type=0)
+            return global_input_tokens, expanded_scale
+
+        # Handle non-quantized case
+        global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
+            global_input_tokens,
+            self.global_input_tokens_local_experts_indices)
+        return global_input_tokens, None
+
+    def _combine_preprocess(self, hidden_states):
+        # Unpermutation 2: expert output to AlltoAll input
+        if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
+            hidden_states = torch_npu.npu_moe_token_unpermute(
+                hidden_states, self.reversed_global_input_permutation_mapping)
+        return hidden_states
+
+    def _combine_postprocess(self, permutated_local_input_tokens):
+        # Unpermutation 1: AlltoAll output to output
+        output = torch_npu.npu_moe_token_unpermute(
+            permuted_tokens=permutated_local_input_tokens,
+            sorted_indices=self.reversed_local_input_permutation_mapping.to(
+                torch.int32),
+            probs=self.topk_weights,
+            restore_shape=self.hidden_shape_before_permute)
+
+        # Reshape the output tensor
+        output = output.view(self.hidden_shape)
+        return output