[3/N] MoE Refactor: Simplify DeepEP Output (#8421)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
|
||||
next_power_of_2,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalOutput,
|
||||
DispatchOutput,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_npu = is_npu()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
@@ -797,6 +806,24 @@ class DeepEPMoE(EPMoE):
|
||||
"alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
|
||||
)
|
||||
|
||||
# TODO: move to the beginning of the file
|
||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=self.num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
params_dtype=params_dtype,
|
||||
deepep_mode=deepep_mode,
|
||||
async_finish=True, # TODO
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
if self.deepep_mode.enable_low_latency():
|
||||
assert (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
@@ -837,37 +864,128 @@ class DeepEPMoE(EPMoE):
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
reorder_topk_ids: torch.Tensor,
|
||||
seg_indptr: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
num_recv_tokens_per_expert: List[int],
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
dispatch_output = self.dispatch(
|
||||
hidden_states, topk_idx, topk_weights, forward_batch
|
||||
)
|
||||
hidden_states = self.moe_impl(dispatch_output)
|
||||
hidden_states = self.combine(
|
||||
hidden_states,
|
||||
dispatch_output.topk_idx,
|
||||
dispatch_output.topk_weights,
|
||||
forward_batch,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
return self.deepep_dispatcher.dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
def moe_impl(self, dispatch_output: DispatchOutput):
|
||||
if _use_aiter:
|
||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(
|
||||
forward_batch.is_extend_in_batch
|
||||
)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
return self.forward_aiter(dispatch_output)
|
||||
if dispatch_output.format.is_deepep_normal():
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
return self.forward_deepgemm_contiguous(
|
||||
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||
)
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
else:
|
||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
||||
return self.forward_normal(dispatch_output)
|
||||
elif dispatch_output.format.is_deepep_ll():
|
||||
return self.forward_deepgemm_masked(dispatch_output)
|
||||
else:
|
||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||
|
||||
def forward_normal(
|
||||
def combine(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
reorder_topk_ids: torch.Tensor,
|
||||
seg_indptr: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
return self.deepep_dispatcher.combine(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
def _prepare_for_normal(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
):
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_permute_triton_kernel,
|
||||
deepep_run_moe_deep_preprocess,
|
||||
)
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(self.num_experts + 1,),
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
return reorder_topk_ids, seg_indptr, hidden_states
|
||||
else:
|
||||
if _use_aiter:
|
||||
# skip permutation here as aiter fused_moe has fused inside
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(self.num_experts + 1,),
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
return reorder_topk_ids, seg_indptr, hidden_states
|
||||
|
||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||
topk_idx, self.num_experts
|
||||
)
|
||||
num_total_tokens = reorder_topk_ids.numel()
|
||||
gateup_input = torch.empty(
|
||||
(int(num_total_tokens), hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
# PreReorder
|
||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
||||
hidden_states,
|
||||
gateup_input,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
None,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return reorder_topk_ids, seg_indptr, gateup_input
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
):
|
||||
hidden_states, topk_idx = (
|
||||
dispatch_output.hidden_states,
|
||||
dispatch_output.topk_idx,
|
||||
)
|
||||
reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
|
||||
hidden_states, topk_idx
|
||||
)
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
|
||||
@@ -983,10 +1101,13 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
def forward_aiter(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
):
|
||||
hidden_states, topk_idx, topk_weights = (
|
||||
dispatch_output.hidden_states,
|
||||
dispatch_output.topk_idx,
|
||||
dispatch_output.topk_weights,
|
||||
)
|
||||
if hidden_states.shape[0] == 0:
|
||||
return hidden_states
|
||||
# in original deepep, idx == -1 meaning invalid and will not be processed.
|
||||
@@ -1014,11 +1135,11 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
def forward_deepgemm_contiguous(
|
||||
self,
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert: List[int],
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
):
|
||||
hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
|
||||
dispatch_output
|
||||
)
|
||||
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
|
||||
assert self.quant_method is not None
|
||||
assert self.activation == "silu"
|
||||
@@ -1138,10 +1259,9 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
def forward_deepgemm_masked(
|
||||
self,
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
):
|
||||
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
assert self.activation == "silu"
|
||||
|
||||
|
||||
@@ -1,7 +1,27 @@
|
||||
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||
BaseDispatcher,
|
||||
BaseDispatcherConfig,
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import (
|
||||
@@ -24,7 +44,6 @@ except ImportError:
|
||||
use_deepep = False
|
||||
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeepEPNormalOutput(NamedTuple):
    """DeepEP normal dispatch output.

    Being a ``NamedTuple``, an instance can be unpacked positionally in field
    order or read by attribute. ``format`` is a property, not a tuple field,
    so it does not take part in unpacking.
    """

    # Either a single tensor or a pair of tensors.
    # NOTE(review): the pair looks like (fp8 data, scale) on the DeepGEMM
    # path — confirm against the dispatcher implementation.
    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    num_recv_tokens_per_expert: List[int]

    @property
    def format(self) -> DispatchOutputFormat:
        """Return the tag that identifies this output as the DeepEP normal format."""
        return DispatchOutputFormat.deepep_normal
|
||||
|
||||
|
||||
class DeepEPLLOutput(NamedTuple):
    """DeepEP low latency dispatch output.

    Being a ``NamedTuple``, an instance can be unpacked positionally in field
    order or read by attribute. ``format`` is a property, not a tuple field,
    so it does not take part in unpacking.
    """

    # A pair of tensors.
    # NOTE(review): presumably (fp8 data, scale) — confirm against the
    # low-latency dispatcher implementation.
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int

    @property
    def format(self) -> DispatchOutputFormat:
        """Return the tag that identifies this output as the DeepEP low-latency format."""
        return DispatchOutputFormat.deepep_ll
|
||||
|
||||
|
||||
# Import-time conformance check against the runtime-checkable DispatchOutput
# protocol. The class object itself (not an instance) is passed, so this
# verifies that DeepEPNormalOutput exposes a `format` attribute.
# NOTE(review): presumably isinstance() on the class is used because
# issubclass() is rejected for protocols with non-method members such as a
# property — confirm. The check is skipped when Python runs with -O.
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||
# Import-time conformance check against the runtime-checkable DispatchOutput
# protocol. The class object itself (not an instance) is passed, so this
# verifies that DeepEPLLOutput exposes a `format` attribute.
# NOTE(review): presumably isinstance() on the class is used because
# issubclass() is rejected for protocols with non-method members such as a
# property — confirm. The check is skipped when Python runs with -O.
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
||||
|
||||
|
||||
class DeepEPDispatchMode(IntEnum):
|
||||
NORMAL = auto()
|
||||
LOW_LATENCY = auto()
|
||||
@@ -139,7 +189,7 @@ class DeepEPBuffer:
|
||||
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
||||
|
||||
|
||||
class DeepEPConfig:
|
||||
class DeepEPConfig(BaseDispatcherConfig):
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
@@ -255,63 +305,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
return hidden_states, topk_idx, topk_weights, previous_event
|
||||
|
||||
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert_list,
|
||||
event,
|
||||
) = self._dispatch_core(
|
||||
hidden_states, topk_idx, topk_weights, previous_event
|
||||
)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
None,
|
||||
num_recv_tokens_per_expert_list,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert_list,
|
||||
event,
|
||||
) = self._dispatch_core(
|
||||
hidden_states, topk_idx, topk_weights, previous_event
|
||||
)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
if hidden_states.shape[0] > 0:
|
||||
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(self.num_experts + 1,),
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
masked_m = expected_m = None
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
None,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert,
|
||||
event,
|
||||
) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
return DeepEPNormalOutput(
|
||||
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||
)
|
||||
|
||||
def _dispatch_core(
|
||||
self,
|
||||
@@ -343,7 +347,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
num_recv_tokens_per_expert_list,
|
||||
num_recv_tokens_per_expert,
|
||||
self.handle,
|
||||
event,
|
||||
) = buffer.dispatch(
|
||||
@@ -362,7 +366,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
)
|
||||
|
||||
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
||||
num_recv_tokens_per_expert_list,
|
||||
num_recv_tokens_per_expert,
|
||||
num_tokens_per_rank=num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
||||
num_tokens_per_expert=num_tokens_per_expert,
|
||||
@@ -372,58 +376,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
num_recv_tokens_per_expert_list,
|
||||
num_recv_tokens_per_expert,
|
||||
event,
|
||||
)
|
||||
|
||||
def _deepep_permute(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
fp8_dtype: Optional[torch.dtype] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_block_quant: bool = False,
|
||||
):
|
||||
"""
|
||||
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
||||
"""
|
||||
if _use_aiter:
|
||||
# skip permutation here as aiter fused_moe has fused inside
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
return reorder_topk_ids, seg_indptr, hidden_states
|
||||
|
||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||
topk_idx, self.num_experts
|
||||
)
|
||||
num_total_tokens = reorder_topk_ids.numel()
|
||||
gateup_input = torch.empty(
|
||||
(int(num_total_tokens), hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=(
|
||||
fp8_dtype
|
||||
if (use_fp8_w8a8 and not use_block_quant)
|
||||
else hidden_states.dtype
|
||||
),
|
||||
)
|
||||
# PreReorder
|
||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
||||
hidden_states,
|
||||
gateup_input,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
None,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return reorder_topk_ids, seg_indptr, gateup_input
|
||||
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -544,15 +500,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
masked_m
|
||||
)
|
||||
|
||||
reorder_topk_ids = seg_indptr = None
|
||||
|
||||
return (
|
||||
return DeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
None,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
@@ -636,7 +587,7 @@ class _Stage(Enum):
|
||||
AFTER_COMBINE_A = auto()
|
||||
|
||||
|
||||
class DeepEPDispatcher:
|
||||
class DeepEPDispatcher(BaseDispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
group: torch.distributed.ProcessGroup,
|
||||
@@ -676,7 +627,7 @@ class DeepEPDispatcher:
|
||||
|
||||
self._stage = _Stage.INITIAL
|
||||
|
||||
def dispatch(self, *args, **kwargs) -> Tuple:
|
||||
def dispatch(self, *args, **kwargs) -> DispatchOutput:
|
||||
self.dispatch_a(*args, **kwargs)
|
||||
ret = self.dispatch_b()
|
||||
return ret
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class DispatchOutputFormat(Enum):
    """Tag identifying which layout a dispatch output uses."""

    standard = auto()
    deepep_normal = auto()
    deepep_ll = auto()

    def is_standard(self) -> bool:
        """Return True when this member is ``standard``."""
        # Enum members are singletons, so an identity test is sufficient.
        return self is DispatchOutputFormat.standard

    def is_deepep_normal(self) -> bool:
        """Return True when this member is ``deepep_normal``."""
        return self is DispatchOutputFormat.deepep_normal

    def is_deepep_ll(self) -> bool:
        """Return True when this member is ``deepep_ll``."""
        return self is DispatchOutputFormat.deepep_ll
|
||||
|
||||
|
||||
@runtime_checkable
class DispatchOutput(Protocol):
    """Protocol for dispatch outputs in different formats.

    Any object exposing a ``format`` attribute satisfies this protocol
    structurally; no inheritance is required. It is decorated with
    ``runtime_checkable`` so that ``isinstance`` can be used against it.
    """

    @property
    def format(self) -> DispatchOutputFormat: ...
|
||||
|
||||
|
||||
class BaseDispatcherConfig(ABC):
    """Base class for dispatcher configs.

    Declares no members of its own; it serves as a common base type for
    concrete dispatcher configuration classes.
    """
|
||||
|
||||
|
||||
class BaseDispatcher(ABC):
    """Base class for dispatchers.

    Concrete dispatchers must implement both ``dispatch`` and ``combine``;
    the class cannot be instantiated until both are overridden.
    """

    @abstractmethod
    def dispatch(self, *args, **kwargs) -> DispatchOutput:
        """Run the dispatch step and return its output."""
        pass

    @abstractmethod
    def combine(self, *args, **kwargs) -> torch.Tensor:
        """Run the combine step and return the resulting tensor."""
        pass
|
||||
19
python/sglang/srt/layers/moe/token_dispatcher/standard.py
Normal file
19
python/sglang/srt/layers/moe/token_dispatcher/standard.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
|
||||
|
||||
class StandardDispatchOutput(NamedTuple):
    """Standard dispatch output.

    Declares no tuple fields; it only carries the ``format`` tag.
    """

    @property
    def format(self) -> DispatchOutputFormat:
        """Return the tag that identifies this output as the standard format."""
        return DispatchOutputFormat.standard
|
||||
|
||||
|
||||
# Import-time conformance check against the runtime-checkable DispatchOutput
# protocol. The class object itself (not an instance) is passed, so this
# verifies that StandardDispatchOutput exposes a `format` attribute.
# NOTE(review): presumably isinstance() on the class is used because
# issubclass() is rejected for protocols with non-method members such as a
# property — confirm. The check is skipped when Python runs with -O.
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
||||
@@ -594,41 +594,13 @@ class DeepseekV2MoE(nn.Module):
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
num_recv_tokens_per_expert,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
) = self.deepep_dispatcher.dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
reorder_topk_ids=reorder_topk_ids,
|
||||
seg_indptr=seg_indptr,
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
final_hidden_states = self.deepep_dispatcher.combine(
|
||||
hidden_states=final_hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if shared_output is not None:
|
||||
x = shared_output
|
||||
@@ -689,8 +661,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def op_dispatch_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
self.deepep_dispatcher.dispatch_a(
|
||||
self.experts.deepep_dispatcher.dispatch_a(
|
||||
hidden_states=state.hidden_states_mlp_input,
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
@@ -703,46 +674,32 @@ class DeepseekV2MoE(nn.Module):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
(
|
||||
state.hidden_states_experts_input,
|
||||
state.topk_idx_dispatched,
|
||||
state.topk_weights_dispatched,
|
||||
state.reorder_topk_ids,
|
||||
state.num_recv_tokens_per_expert,
|
||||
state.seg_indptr,
|
||||
state.masked_m,
|
||||
state.expected_m,
|
||||
) = self.deepep_dispatcher.dispatch_b(
|
||||
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_experts(self, state):
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_experts_input"),
|
||||
topk_idx=state.topk_idx_dispatched,
|
||||
topk_weights=state.topk_weights_dispatched,
|
||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
||||
seg_indptr=state.pop("seg_indptr"),
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_batch=state.forward_batch,
|
||||
state.hidden_states_experts_output = self.experts.moe_impl(
|
||||
dispatch_output=state.dispatch_output,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
self.deepep_dispatcher.combine_a(
|
||||
self.experts.deepep_dispatcher.combine_a(
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.pop("topk_idx_dispatched"),
|
||||
topk_weights=state.pop("topk_weights_dispatched"),
|
||||
topk_idx=state.dispatch_output.topk_idx,
|
||||
topk_weights=state.dispatch_output.topk_weights,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
state.pop("dispatch_output")
|
||||
|
||||
def op_combine_b(self, state):
|
||||
if self.ep_size > 1:
|
||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
state.hidden_states_after_combine = (
|
||||
self.experts.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
)
|
||||
|
||||
def op_output(self, state):
|
||||
|
||||
@@ -144,19 +144,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
)
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=parallel_state.get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=config.num_experts // self.tp_size,
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
||||
async_finish=True, # TODO
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
||||
) -> torch.Tensor:
|
||||
@@ -207,41 +194,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
num_recv_tokens_per_expert,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
) = self.deepep_dispatcher.dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
reorder_topk_ids=reorder_topk_ids,
|
||||
seg_indptr=seg_indptr,
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
final_hidden_states = self.deepep_dispatcher.combine(
|
||||
hidden_states=final_hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
def op_gate(self, state):
|
||||
@@ -278,8 +236,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
|
||||
def op_dispatch_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
self.deepep_dispatcher.dispatch_a(
|
||||
self.experts.deepep_dispatcher.dispatch_a(
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
@@ -292,46 +249,32 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
(
|
||||
state.hidden_states_experts_input,
|
||||
state.topk_idx_dispatched,
|
||||
state.topk_weights_dispatched,
|
||||
state.reorder_topk_ids,
|
||||
state.num_recv_tokens_per_expert,
|
||||
state.seg_indptr,
|
||||
state.masked_m,
|
||||
state.expected_m,
|
||||
) = self.deepep_dispatcher.dispatch_b(
|
||||
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_experts(self, state):
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_experts_input"),
|
||||
topk_idx=state.topk_idx_dispatched,
|
||||
topk_weights=state.topk_weights_dispatched,
|
||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
||||
seg_indptr=state.pop("seg_indptr"),
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_batch=state.forward_batch,
|
||||
state.hidden_states_experts_output = self.experts.moe_impl(
|
||||
dispatch_output=state.dispatch_output,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
self.deepep_dispatcher.combine_a(
|
||||
self.experts.deepep_dispatcher.combine_a(
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.pop("topk_idx_dispatched"),
|
||||
topk_weights=state.pop("topk_weights_dispatched"),
|
||||
topk_idx=state.dispatch_output.topk_idx,
|
||||
topk_weights=state.dispatch_output.topk_weights,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
state.pop("dispatch_output")
|
||||
|
||||
def op_combine_b(self, state):
|
||||
if self.ep_size > 1:
|
||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
state.hidden_states_after_combine = (
|
||||
self.experts.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
)
|
||||
|
||||
def op_output(self, state):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from dataclasses import replace
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,6 +22,9 @@ from sglang.srt.operations_strategy import OperationsStrategy
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
|
||||
|
||||
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -802,7 +807,7 @@ class MaybeTboDeepEPDispatcher:
|
||||
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
|
||||
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
|
||||
|
||||
def dispatch(self, **kwargs):
|
||||
def dispatch(self, **kwargs) -> DispatchOutput:
|
||||
return self._execute("dispatch", **kwargs)
|
||||
|
||||
def dispatch_a(self, **kwargs):
|
||||
@@ -811,7 +816,7 @@ class MaybeTboDeepEPDispatcher:
|
||||
def dispatch_b(self, **kwargs):
|
||||
return self._execute("dispatch_b", **kwargs)
|
||||
|
||||
def combine(self, **kwargs):
|
||||
def combine(self, **kwargs) -> torch.Tensor:
|
||||
return self._execute("combine", **kwargs)
|
||||
|
||||
def combine_a(self, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user