diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index b978eaf3a..5ba8d2c42 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
     next_power_of_2,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
@@ -797,6 +806,24 @@ class DeepEPMoE(EPMoE):
                 "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
             )
 
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
+
         if self.deepep_mode.enable_low_latency():
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -837,37 +864,128 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
+        )
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def moe_impl(self, dispatch_output: DispatchOutput):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
-        if resolved_deepep_mode == DeepEPMode.normal:
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
+                return self.forward_deepgemm_contiguous(dispatch_output)
             else:
-                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
+                return self.forward_normal(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
-    def forward_normal(
+    def combine(
         self,
         hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def _prepare_for_normal(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+    ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_permute_triton_kernel,
+            deepep_run_moe_deep_preprocess,
+        )
+
+        if hidden_states.shape[0] == 0:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,),
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
+        else:
+            if _use_aiter:
+                # skip permutation here as aiter fused_moe has fused inside
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
+                return reorder_topk_ids, seg_indptr, hidden_states
+
+            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+                topk_idx, self.num_experts
+            )
+            num_total_tokens = reorder_topk_ids.numel()
+            gateup_input = torch.empty(
+                (int(num_total_tokens), hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            # PreReorder
+            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+                hidden_states,
+                gateup_input,
+                self.src2dst,
+                topk_idx,
+                None,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+            return reorder_topk_ids, seg_indptr, gateup_input
+
+    def forward_normal(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, topk_idx = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+        )
+        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
+            hidden_states, topk_idx
+        )
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
 
@@ -983,10 +1101,13 @@ class DeepEPMoE(EPMoE):
 
     def forward_aiter(
         self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +1135,11 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_contiguous(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
@@ -1138,10 +1259,9 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_masked(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
     ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"
 
diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 5c0cd3ec9..b1aee3a93 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,7 +1,27 @@
+# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
+
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    List,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    Union,
+    runtime_checkable,
+)
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
@@ -24,7 +44,6 @@ except ImportError:
     use_deepep = False
 
 from enum import Enum, IntEnum, auto
-from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)
 
 
+class DeepEPNormalOutput(NamedTuple):
+    """DeepEP normal dispatch output."""
+
+    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    num_recv_tokens_per_expert: List[int]
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_normal
+
+
+class DeepEPLLOutput(NamedTuple):
+    """DeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
+
+
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
     LOW_LATENCY = auto()
@@ -139,7 +189,7 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
-class DeepEPConfig:
+class DeepEPConfig(BaseDispatcherConfig):
     _instance = None
 
     def __init__(self):
@@ -255,63 +305,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                None,
-                num_recv_tokens_per_expert_list,
-                None,
-                None,
-                None,
-            )
-        else:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-            else:
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-
-            masked_m = expected_m = None
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                None,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            )
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert,
+            event,
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
+        return DeepEPNormalOutput(
+            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+        )
 
     def _dispatch_core(
         self,
@@ -343,7 +347,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -362,7 +366,7 @@
         )
 
         get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             num_tokens_per_expert=num_tokens_per_expert,
@@ -372,58 +376,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             event,
         )
 
-    def _deepep_permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
-    ):
-        """
-        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-        """
-        if _use_aiter:
-            # skip permutation here as aiter fused_moe has fused inside
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
-
     def combine_a(
         self,
         hidden_states: torch.Tensor,
@@ -544,15 +500,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        reorder_topk_ids = seg_indptr = None
-
-        return (
+        return DeepEPLLOutput(
             hidden_states,
             topk_idx,
             topk_weights,
-            reorder_topk_ids,
-            None,
-            seg_indptr,
             masked_m,
             expected_m,
         )
@@ -636,7 +587,7 @@ class _Stage(Enum):
     AFTER_COMBINE_A = auto()
 
 
-class DeepEPDispatcher:
+class DeepEPDispatcher(BaseDispatcher):
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
@@ -676,7 +627,7 @@ class DeepEPDispatcher:
 
         self._stage = _Stage.INITIAL
 
-    def dispatch(self, *args, **kwargs) -> Tuple:
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
         return ret
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py b/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
new file mode 100644
index 000000000..7167fe759
--- /dev/null
+++ b/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
+
+import torch
+
+
+class DispatchOutputFormat(Enum):
+    standard = auto()
+    deepep_normal = auto()
+    deepep_ll = auto()
+
+    def is_standard(self) -> bool:
+        return self == DispatchOutputFormat.standard
+
+    def is_deepep_normal(self) -> bool:
+        return self == DispatchOutputFormat.deepep_normal
+
+    def is_deepep_ll(self) -> bool:
+        return self == DispatchOutputFormat.deepep_ll
+
+
+@runtime_checkable
+class DispatchOutput(Protocol):
+    """Protocol for dispatch outputs in different formats."""
+
+    @property
+    def format(self) -> DispatchOutputFormat: ...
+
+
+class BaseDispatcherConfig(ABC):
+    """Base class for dispatcher configs."""
+
+    pass
+
+
+class BaseDispatcher(ABC):
+    """Base class for dispatchers."""
+
+    @abstractmethod
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
+        pass
+
+    @abstractmethod
+    def combine(self, *args, **kwargs) -> torch.Tensor:
+        pass
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
new file mode 100644
index 000000000..4a2d2dd6b
--- /dev/null
+++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+
+
+class StandardDispatchOutput(NamedTuple):
+    """Standard dispatch output."""
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.standard
+
+
+assert isinstance(StandardDispatchOutput, DispatchOutput)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 777b8e0c8..b5305f923 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -594,41 +594,13 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
+
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
 
         if shared_output is not None:
             x = shared_output
@@ -689,8 +661,7 @@ class DeepseekV2MoE(nn.Module):
 
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
@@ -703,46 +674,32 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.pop("hidden_states_experts_input"),
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )
 
     def op_output(self, state):
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 01235f7ac..a1faa894d 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -144,19 +144,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )
         self.top_k = config.num_experts_per_tok
 
-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=parallel_state.get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=config.num_experts // self.tp_size,
-            hidden_size=config.hidden_size,
-            params_dtype=config.torch_dtype,
-            deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-            async_finish=True,  # TODO
-            return_recv_hook=True,
-        )
-
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
@@ -207,41 +194,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
         return final_hidden_states
 
     def op_gate(self, state):
@@ -278,8 +236,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
     def op_dispatch_a(self, state):
        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.pop("hidden_states_mlp_input"),
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
@@ -292,46 +249,32 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.pop("hidden_states_experts_input"),
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )
 
     def op_output(self, state):
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index e802a7254..d65d8d598 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import dataclasses
 import logging
 from dataclasses import replace
-from typing import Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
 
 import torch
 
@@ -20,6 +22,9 @@ from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -802,7 +807,7 @@ class MaybeTboDeepEPDispatcher:
     def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
         return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
 
-    def dispatch(self, **kwargs):
+    def dispatch(self, **kwargs) -> DispatchOutput:
        return self._execute("dispatch", **kwargs)
 
     def dispatch_a(self, **kwargs):
@@ -811,7 +816,7 @@ class MaybeTboDeepEPDispatcher:
     def dispatch_b(self, **kwargs):
         return self._execute("dispatch_b", **kwargs)
 
-    def combine(self, **kwargs):
+    def combine(self, **kwargs) -> torch.Tensor:
         return self._execute("combine", **kwargs)
 
     def combine_a(self, **kwargs):