[Refactor] Formatting output types related to FusedMoE (#5481)

Currently in the Fused MoE module, methods of classes such as
MoECommMethod and MoETokenDispatcher return their results as raw
dictionaries or tuples, which hampers code maintainability,
readability, and extensibility. This PR introduces dataclasses for
these key output types to address these issues.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2025-12-31 14:24:37 +08:00
committed by GitHub
parent 38570cfeb6
commit 7d5242faca
6 changed files with 155 additions and 212 deletions

View File

@@ -136,10 +136,10 @@ def test_token_dispatcher_with_all_gather(
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
context_metadata = dispatch_output["context_metadata"]
sorted_hidden_states = dispatch_output.hidden_states
group_list = dispatch_output.group_list
group_list_type = dispatch_output.group_list_type
context_metadata = dispatch_output.context_metadata
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
@@ -155,7 +155,7 @@ def test_token_dispatcher_with_all_gather(
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)
torch.testing.assert_close(combined_output,
torch.testing.assert_close(combined_output.routed_out,
torch_output,
atol=4e-2,
rtol=1)
@@ -216,11 +216,11 @@ def test_token_dispatcher_with_all_gather_quant(
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
dynamic_scale = dispatch_output["dynamic_scale"]
context_metadata = dispatch_output["context_metadata"]
sorted_hidden_states = dispatch_output.hidden_states
group_list = dispatch_output.group_list
group_list_type = dispatch_output.group_list_type
dynamic_scale = dispatch_output.dynamic_scale
context_metadata = dispatch_output.context_metadata
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
w1=w1,
@@ -235,7 +235,7 @@ def test_token_dispatcher_with_all_gather_quant(
hidden_states=expert_output,
context_metadata=context_metadata,
bias=None)
assert combined_output.shape == (m, k)
assert combined_output.routed_out.shape == (m, k)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -8,6 +8,8 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
TokenDispatchResult)
class TestMoECommMethod(TestBase):
@@ -178,12 +180,12 @@ class TestMoECommMethod(TestBase):
# Mock token dispatcher
mock_td_instance = MagicMock()
mock_td_instance.token_dispatch.return_value = {
"hidden_states": torch.randn(6, 8),
"group_list": torch.tensor([2, 2, 2]),
"group_list_type": 1
}
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
hidden_states=torch.randn(6, 8),
group_list=torch.tensor([2, 2, 2]),
group_list_type=1)
mock_td_instance.token_combine.return_value = TokenCombineResult(
routed_out=torch.randn(4, 8))
mock_token_dispatcher.return_value = mock_td_instance
# Mock unified_apply_mlp
@@ -213,7 +215,7 @@ class TestMoECommMethod(TestBase):
activation="silu")
# Verify result shape
self.assertEqual(result.shape, (4, 8))
self.assertEqual(result.routed_out.shape, (4, 8))
# Verify token_dispatch was called
mock_td_instance.token_dispatch.assert_called_once()

View File

@@ -97,8 +97,7 @@ class TestTokenDispatcherWithMC2(TestBase):
topk_weights, topk_ids,
expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
0) # group_list_type == 0
self.assertEqual(output.group_list_type, 0) # group_list_type == 0
def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
@@ -149,43 +148,6 @@ class TestTokenDispatcherWithMC2(TestBase):
context_metadata)
self.assertIn("tp_send_counts", kwargs)
def test_token_combine_with_shared_experts(self):
shared_experts = MagicMock()
shared_experts.down_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
assist_info_for_combine = torch.arange(10)
tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": None,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
"shared_experts": shared_experts,
"shared_act": torch.randn(10, 128),
"swiglu_out_scale": torch.randn(10, 1),
"tp_recv_counts": tp_recv_counts
}
self.dispatcher.with_quant = True
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
result = self.dispatcher.token_combine(hidden_states,
context_metadata)
self.assertIsInstance(result, tuple)
class TestTokenDispatcherWithAllGather(TestBase):
@@ -233,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)
def test_token_dispatch_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
@@ -248,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.mock_npu_moe_init_routing_v2.assert_called_once()
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)
def test_token_dispatch_without_quant(self):
kwargs = {
@@ -268,7 +230,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_weights, topk_ids,
None)
self.assertEqual(results["group_list_type"], 1)
self.assertEqual(results.group_list_type, 1)
def test_token_dispatch_with_quant(self):
kwargs = {
@@ -290,10 +252,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
None,
with_quant=True)
self.assertIsNotNone(results["hidden_states"])
self.assertIsNotNone(results["group_list"])
self.assertIsNotNone(results["dynamic_scale"])
self.assertEqual(results["group_list_type"], 1)
self.assertIsNotNone(results.hidden_states)
self.assertIsNotNone(results.group_list)
self.assertIsNotNone(results.dynamic_scale)
self.assertEqual(results.group_list_type, 1)
def test_token_combine_with_expert_map(self):
hidden_states = torch.randn(6, 128)
@@ -303,7 +265,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
hidden_states, context_metadata).routed_out
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_combine_without_expert_map(self):
@@ -314,7 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
hidden_states, context_metadata).routed_out
self.mock_npu_moe_token_unpermute.assert_called_once()
self.assertEqual(final_hidden_states.shape, (6, 128))
@@ -326,7 +288,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
self.assertEqual(results["hidden_states"].shape, (6, 128))
self.assertEqual(results.hidden_states.shape, (6, 128))
class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -437,9 +399,9 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
topk_ids=topk_ids,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)
def test_token_combine(self):
hidden_states = torch.randn(16, 16)
@@ -458,7 +420,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
output = self.dispatcher.token_combine(hidden_states, context_metadata)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))
self.assertEqual(output.routed_out.shape, (8, 16))
def test_token_dispatch_with_quant(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
@@ -480,10 +442,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)
def test_token_dispatch_with_quant_no_active_tokens(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
@@ -508,10 +470,10 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertIsNotNone(result.dynamic_scale)
self.assertEqual(result.group_list_type, 1)
def test_token_dispatch_with_log2phy(self):
hidden_states = torch.randn(8, 16)
@@ -530,6 +492,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
expert_map=expert_map,
log2phy=log2phy)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
self.assertIsNotNone(result.hidden_states)
self.assertIsNotNone(result.group_list)
self.assertEqual(result.group_list_type, 1)

View File

@@ -37,6 +37,7 @@ from vllm_ascend.flash_common3_context import (get_flash_common3_context,
set_flash_common3_context)
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
FusedExpertsResult,
setup_moe_comm_method)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.quantization.w4a8_dynamic import \
@@ -325,7 +326,7 @@ class AscendFusedMoE(FusedMoE):
pertoken_scale = None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
@@ -350,25 +351,25 @@ class AscendFusedMoE(FusedMoE):
global_redundant_expert_num=self.global_redundant_expert_num,
mc2_mask=mc2_mask)
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
if self.dynamic_eplb:
if self.dynamic_eplb:
expert_tokens = fused_experts_results.expert_tokens
group_list_type = fused_experts_results.group_list_type
assert expert_tokens is not None and group_list_type is not None, \
"expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
moe_load_stream = moe_load_async_stream()
cur_stream = torch.npu.current_stream()
moe_load_stream.wait_stream(cur_stream)
with npu_stream_switch(moe_load_stream):
self.moe_load += expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
cur_stream.wait_stream(moe_load_stream)
moe_load_stream = moe_load_async_stream()
cur_stream = torch.npu.current_stream()
moe_load_stream.wait_stream(cur_stream)
with npu_stream_switch(moe_load_stream):
self.moe_load += expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
cur_stream.wait_stream(moe_load_stream)
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
routed_out = forward_context.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata)
return final_hidden_states
return routed_out
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
@@ -439,7 +440,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
else:
set_flash_common3_context(shared_experts=self._shared_experts)
fused_output = AscendFusedMoE.forward_impl(
routed_out = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
@@ -462,4 +463,4 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
assert fc3_context is not None
shared_out = fc3_context.shared_out
return shared_out, fused_output
return shared_out, routed_out

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
@@ -26,11 +27,11 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
PrepareAndFinalizeWithMC2, QuantType)
PrepareAndFinalize, PrepareAndFinalizeWithAll2All,
PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType)
from vllm_ascend.ops.fused_moe.token_dispatcher import (
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2)
MoETokenDispatcher, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -47,6 +48,14 @@ def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)
@dataclass
class FusedExpertsResult:
routed_out: torch.Tensor
# For dynamic_eplb
group_list_type: int | None = None
expert_tokens: torch.Tensor | None = None
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
@@ -118,7 +127,7 @@ class MoECommMethod(ABC):
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
results = self.token_dispatcher.token_dispatch(
dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
@@ -134,43 +143,41 @@ class MoECommMethod(ABC):
dynamic_eplb=dynamic_eplb,
pertoken_scale=pertoken_scale)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=dispatch_results.group_list,
dynamic_scale=dispatch_results.dynamic_scale,
group_list_type=dispatch_results.group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
topk_scales=dispatch_results.topk_scales,
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
fusion=use_int8_w8a8,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=expert_tokens,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
topk_scales=topk_scales,
with_quant=use_int8_w8a8
or use_int4_w4a8 or use_int4_w4a16,
fusion=use_int8_w8a8,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output,
context_metadata=dispatch_results.context_metadata)
final_hidden_states = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=context_metadata)
if dynamic_eplb:
return (final_hidden_states, group_list_type, expert_tokens)
return final_hidden_states
return FusedExpertsResult(
routed_out=combine_results.routed_out,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list)
@abstractmethod
def _get_token_dispatcher(self):
def _get_token_dispatcher(self) -> MoETokenDispatcher:
raise NotImplementedError(
"_get_token_dispatcher function not implemented.")
@abstractmethod
def _get_prepare_finalize(self):
def _get_prepare_finalize(self) -> PrepareAndFinalize:
raise NotImplementedError(
"_get_prepare_finalize function not implemented.")
@@ -292,9 +299,11 @@ class FusedMC2CommImpl(MoECommMethod):
w1_scale is None or w2_scale is None
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
out = torch.empty_like(hidden_states)
torch.ops._C_ascend.dispatch_ffn_combine(
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
x=hidden_states,
weight1=w1[0],
weight2=w2[0],
@@ -308,7 +317,7 @@ class FusedMC2CommImpl(MoECommMethod):
)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
assert expert_map is not None, "expert_map cannot be None."
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
x=hidden_states,
expert_ids=topk_ids,
gmm1_permuted_weight=w1[0],
@@ -325,4 +334,4 @@ class FusedMC2CommImpl(MoECommMethod):
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
return out
return FusedExpertsResult(routed_out=out)

View File

@@ -21,6 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional
import torch
@@ -35,6 +36,21 @@ from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
is_hierarchical_communication_enabled)
@dataclass
class TokenDispatchResult:
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
dynamic_scale: torch.Tensor | None = field(default=None)
topk_scales: torch.Tensor | None = field(default=None)
context_metadata: dict = field(default_factory=dict)
@dataclass
class TokenCombineResult:
routed_out: torch.Tensor
class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None:
@@ -74,14 +90,14 @@ class MoETokenDispatcher(ABC):
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None,
):
) -> TokenDispatchResult:
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
bias: torch.Tensor | None = None) -> TokenCombineResult:
raise NotImplementedError("Combine function not implemented.")
@@ -207,24 +223,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
ep_recv_counts, tp_recv_counts, expand_scales = output[0:7]
# Handle shared experts (store intermediate results in local vars, not self)
shared_act = None
swiglu_out_scale = None
if with_quant:
if shared_experts is not None:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
shared_act, swiglu_out_scale = shared_act_out[
0], shared_act_out[1]
else:
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
shared_act = shared_experts.act_fn(shared_gate_up)
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
@@ -233,20 +231,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"shared_experts": shared_experts,
"shared_act": shared_act,
"swiglu_out_scale": swiglu_out_scale,
"expand_scales": expand_scales
}
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
"context_metadata": context_metadata
}
return TokenDispatchResult(hidden_states=expand_x,
dynamic_scale=dynamic_scale,
group_list=expert_token_nums,
group_list_type=group_list_type,
context_metadata=context_metadata)
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
context_metadata: dict):
@@ -300,12 +294,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
def token_combine(
self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None,
):
def token_combine(self, hidden_states, context_metadata, bias=None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
@@ -313,20 +302,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
# Handle shared experts from metadata
shared_experts = context_metadata["shared_experts"]
if shared_experts is None:
return combined_output
shared_act = context_metadata["shared_act"]
if self.with_quant:
swiglu_out_scale = context_metadata["swiglu_out_scale"]
shared_hidden_states, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
else:
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
return combined_output, shared_hidden_states
return TokenCombineResult(routed_out=combined_output)
class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -401,18 +377,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
"topk_weights": topk_weights,
"expanded_row_idx": expanded_row_idx
}
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
"dynamic_scale": pertoken_scale if self.with_quant else None,
"context_metadata": context_metadata
}
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
return TokenDispatchResult(
hidden_states=sorted_hidden_states,
dynamic_scale=pertoken_scale if self.with_quant else None,
group_list=expert_tokens,
group_list_type=group_list_type,
context_metadata=context_metadata,
)
def token_combine(self, hidden_states, context_metadata, bias=None):
assert self.original_shape is not None
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
@@ -422,7 +396,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
final_hidden_states = final_hidden_states.view(self.original_shape)
# these values are no longer used, so they need to be set to None for memory release.
return final_hidden_states
return TokenCombineResult(routed_out=final_hidden_states)
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
@@ -530,20 +504,15 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
reversed_global_input_permutation_mapping
}
return {
"hidden_states": global_input_tokens,
"group_list": tokens_per_expert,
"group_list_type": 1,
"dynamic_scale": dynamic_scale_final,
"context_metadata": context_metadata,
}
return TokenDispatchResult(
hidden_states=global_input_tokens,
dynamic_scale=dynamic_scale_final,
group_list=tokens_per_expert,
group_list_type=1,
context_metadata=context_metadata,
)
def token_combine(
self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None,
):
def token_combine(self, hidden_states, context_metadata, bias=None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
# 1. Preprocess using metadata
@@ -564,7 +533,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
output = self._combine_postprocess(permutated_local_input_tokens,
context_metadata)
return output
return TokenCombineResult(routed_out=output)
def _dispatch_preprocess(self, hidden_states, topk_ids):
assert self.hidden_shape is not None