diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index fe9fbad67..663ca1877 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -18,7 +18,7 @@ try:
 except ImportError:
     use_deepep = False
 
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from typing import Optional, Tuple, Union
 
 import torch
@@ -627,6 +627,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
 
 
+class _Stage(Enum):
+    INITIAL = auto()
+    AFTER_DISPATCH_A = auto()
+    AFTER_DISPATCH_B = auto()
+    AFTER_COMBINE_A = auto()
+
+
 class DeepEPDispatcher:
     def __init__(
         self,
@@ -665,6 +672,8 @@ class DeepEPDispatcher:
             **common_kwargs,
         )
 
+        self._stage = _Stage.INITIAL
+
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
@@ -677,6 +686,7 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode = None,
     ):
+        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
@@ -685,6 +695,7 @@ class DeepEPDispatcher:
         self._dispatch_intermediate_state = forward_mode, inner_state
 
     def dispatch_b(self):
+        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
         forward_mode, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
         return self._get_impl(forward_mode).dispatch_b(*inner_state)
@@ -701,6 +712,7 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode,
     ):
+        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl(forward_mode).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
@@ -709,6 +721,7 @@ class DeepEPDispatcher:
         self._combine_intermediate_state = forward_mode, inner_state
 
     def combine_b(self):
+        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
         forward_mode, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
@@ -721,3 +734,7 @@ class DeepEPDispatcher:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+    def _update_stage(self, old_stage, new_stage):
+        assert self._stage == old_stage
+        self._stage = new_stage