Tiny add stage assertions to DeepEPDispatcher to avoid misuse (#6467)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.managers.expert_distribution import (
|
from sglang.srt.managers.expert_distribution import (
|
||||||
@@ -18,7 +19,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
use_deepep = False
|
use_deepep = False
|
||||||
|
|
||||||
from enum import IntEnum, auto
|
from enum import Enum, IntEnum, auto
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -627,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Stage(Enum):
|
||||||
|
INITIAL = auto()
|
||||||
|
AFTER_DISPATCH_A = auto()
|
||||||
|
AFTER_DISPATCH_B = auto()
|
||||||
|
AFTER_COMBINE_A = auto()
|
||||||
|
|
||||||
|
|
||||||
class DeepEPDispatcher:
|
class DeepEPDispatcher:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -665,6 +674,8 @@ class DeepEPDispatcher:
|
|||||||
**common_kwargs,
|
**common_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._stage = _Stage.INITIAL
|
||||||
|
|
||||||
def dispatch(self, *args, **kwargs) -> Tuple:
|
def dispatch(self, *args, **kwargs) -> Tuple:
|
||||||
self.dispatch_a(*args, **kwargs)
|
self.dispatch_a(*args, **kwargs)
|
||||||
ret = self.dispatch_b()
|
ret = self.dispatch_b()
|
||||||
@@ -677,6 +688,7 @@ class DeepEPDispatcher:
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_mode: ForwardMode = None,
|
forward_mode: ForwardMode = None,
|
||||||
):
|
):
|
||||||
|
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
||||||
inner_state = self._get_impl(forward_mode).dispatch_a(
|
inner_state = self._get_impl(forward_mode).dispatch_a(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
@@ -685,6 +697,7 @@ class DeepEPDispatcher:
|
|||||||
self._dispatch_intermediate_state = forward_mode, inner_state
|
self._dispatch_intermediate_state = forward_mode, inner_state
|
||||||
|
|
||||||
def dispatch_b(self):
|
def dispatch_b(self):
|
||||||
|
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
|
||||||
forward_mode, inner_state = self._dispatch_intermediate_state
|
forward_mode, inner_state = self._dispatch_intermediate_state
|
||||||
del self._dispatch_intermediate_state
|
del self._dispatch_intermediate_state
|
||||||
return self._get_impl(forward_mode).dispatch_b(*inner_state)
|
return self._get_impl(forward_mode).dispatch_b(*inner_state)
|
||||||
@@ -701,6 +714,7 @@ class DeepEPDispatcher:
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
):
|
):
|
||||||
|
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||||
inner_state = self._get_impl(forward_mode).combine_a(
|
inner_state = self._get_impl(forward_mode).combine_a(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
@@ -709,6 +723,7 @@ class DeepEPDispatcher:
|
|||||||
self._combine_intermediate_state = forward_mode, inner_state
|
self._combine_intermediate_state = forward_mode, inner_state
|
||||||
|
|
||||||
def combine_b(self):
|
def combine_b(self):
|
||||||
|
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
|
||||||
forward_mode, inner_state = self._combine_intermediate_state
|
forward_mode, inner_state = self._combine_intermediate_state
|
||||||
del self._combine_intermediate_state
|
del self._combine_intermediate_state
|
||||||
return self._get_impl(forward_mode).combine_b(*inner_state)
|
return self._get_impl(forward_mode).combine_b(*inner_state)
|
||||||
@@ -721,3 +736,7 @@ class DeepEPDispatcher:
|
|||||||
return self._low_latency_dispatcher
|
return self._low_latency_dispatcher
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||||
|
|
||||||
|
def _update_stage(self, old_stage, new_stage):
|
||||||
|
assert self._stage == old_stage
|
||||||
|
self._stage = new_stage
|
||||||
|
|||||||
Reference in New Issue
Block a user