Tiny: add stage assertions to DeepEPDispatcher to avoid misuse (#6467)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
@@ -18,7 +19,7 @@ try:
|
||||
except ImportError:
|
||||
use_deepep = False
|
||||
|
||||
from enum import IntEnum, auto
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -627,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Stage(Enum):
|
||||
INITIAL = auto()
|
||||
AFTER_DISPATCH_A = auto()
|
||||
AFTER_DISPATCH_B = auto()
|
||||
AFTER_COMBINE_A = auto()
|
||||
|
||||
|
||||
class DeepEPDispatcher:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -665,6 +674,8 @@ class DeepEPDispatcher:
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
self._stage = _Stage.INITIAL
|
||||
|
||||
def dispatch(self, *args, **kwargs) -> Tuple:
|
||||
self.dispatch_a(*args, **kwargs)
|
||||
ret = self.dispatch_b()
|
||||
@@ -677,6 +688,7 @@ class DeepEPDispatcher:
|
||||
topk_weights: torch.Tensor,
|
||||
forward_mode: ForwardMode = None,
|
||||
):
|
||||
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
||||
inner_state = self._get_impl(forward_mode).dispatch_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
@@ -685,6 +697,7 @@ class DeepEPDispatcher:
|
||||
self._dispatch_intermediate_state = forward_mode, inner_state
|
||||
|
||||
def dispatch_b(self):
    """Complete the dispatch phase begun by dispatch_a().

    Consumes the (forward_mode, inner_state) pair stashed by dispatch_a,
    then delegates to the mode-specific implementation.
    """
    self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
    mode, saved = self._dispatch_intermediate_state
    # Drop the stash immediately so an accidental second dispatch_b()
    # fails loudly instead of replaying stale state.
    del self._dispatch_intermediate_state
    impl = self._get_impl(mode)
    return impl.dispatch_b(*saved)
|
||||
@@ -701,6 +714,7 @@ class DeepEPDispatcher:
|
||||
topk_weights: torch.Tensor,
|
||||
forward_mode: ForwardMode,
|
||||
):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||
inner_state = self._get_impl(forward_mode).combine_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
@@ -709,6 +723,7 @@ class DeepEPDispatcher:
|
||||
self._combine_intermediate_state = forward_mode, inner_state
|
||||
|
||||
def combine_b(self):
    """Complete the combine phase begun by combine_a().

    Consumes the (forward_mode, inner_state) pair stashed by combine_a,
    then delegates to the mode-specific implementation.
    """
    self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
    mode, saved = self._combine_intermediate_state
    # Drop the stash immediately so an accidental second combine_b()
    # fails loudly instead of replaying stale state.
    del self._combine_intermediate_state
    impl = self._get_impl(mode)
    return impl.combine_b(*saved)
|
||||
@@ -721,3 +736,7 @@ class DeepEPDispatcher:
|
||||
return self._low_latency_dispatcher
|
||||
else:
|
||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||
|
||||
def _update_stage(self, old_stage, new_stage):
|
||||
assert self._stage == old_stage
|
||||
self._stage = new_stage
|
||||
|
||||
Reference in New Issue
Block a user