diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 825ba46c6..8ddad6096 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -86,7 +86,7 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher: a2a_backend = get_moe_a2a_backend() if a2a_backend.is_none(): return StandardDispatcher(moe_runner_config) - elif a2a_backend.is_deepep(): + elif a2a_backend.is_deepep() or a2a_backend.is_mooncake(): return MaybeTboDeepEPDispatcher( group=get_tp_group().device_group, router_topk=moe_runner_config.top_k, diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py index 201c1b5f2..f195a7994 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py @@ -36,7 +36,7 @@ class MooncakeDispatchOutput(NamedTuple): """Mooncake EP dispatch output.""" hidden_states: torch.Tensor - hidden_states_scale: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] topk_ids: torch.Tensor topk_weights: torch.Tensor masked_m: torch.Tensor @@ -205,8 +205,14 @@ class _MooncakeEPDispatcherImpl: masked_m ) + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + else: + hidden_states_scale = None + return MooncakeDispatchOutput( hidden_states, + hidden_states_scale, topk_ids, topk_weights, masked_m,