[1/N] Introduce Mooncake Backend and Mooncake EP to Support Elastic EP (#10423)
Co-authored-by: Hank Han <hanhan7630@outlook.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
@@ -43,6 +43,7 @@ from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
    get_int_env_var,
+    get_local_ip_auto,
    is_cpu,
    is_cuda_alike,
    is_hip,
@@ -258,11 +259,14 @@ class GroupCoordinator:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
-            # a group with `gloo` backend, to allow direct coordination between
-            # processes through the CPU.
-            cpu_group = torch.distributed.new_group(
-                ranks, backend="gloo", timeout=gloo_timeout
-            )
+            # a cpu_group to allow direct coordination between processes through
+            # the CPU. The backend is chosen based on `torch_distributed_backend`.
+            if "mooncake" in torch_distributed_backend:
+                cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu")
+            else:
+                cpu_group = torch.distributed.new_group(
+                    ranks, backend="gloo", timeout=gloo_timeout
+                )
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
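For context, here is a minimal standalone sketch of the two-group pattern this hunk extends; the `nccl`/`gloo` pairing and the `env://` rendezvous are illustrative assumptions, not part of this commit:

```python
import torch
import torch.distributed as dist

# Device collectives run on the device backend; small control-plane syncs go
# through the CPU group, which this commit swaps to "mooncake-cpu" whenever the
# main backend string contains "mooncake".
dist.init_process_group(backend="nccl", init_method="env://")
ranks = list(range(dist.get_world_size()))
device_group = dist.new_group(ranks, backend="nccl")
cpu_group = dist.new_group(ranks, backend="gloo")

t = torch.ones(1, device="cuda")
dist.all_reduce(t, group=device_group)  # GPU data path
dist.barrier(group=cpu_group)           # CPU coordination path
```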
@@ -1410,6 +1414,17 @@ def init_distributed_environment(
        distributed_init_method,
        backend,
    )
+    if "mooncake" in backend:
+        try:
+            from mooncake import ep as mooncake_ep
+        except ImportError as e:
+            raise ImportError(
+                "Please install mooncake by following the instructions at "
+                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
+                "to run SGLang with Mooncake Backend."
+            ) from e
+        mooncake_ep.set_host_ip(get_local_ip_auto())
+
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
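A sketch of the ordering this hunk enforces; the requirement that the host IP be set before process-group creation is inferred from where the hunk sits, just above the `torch.distributed.is_initialized()` check:

```python
# Sketch: hand mooncake the host IP before the process group is created.
from mooncake import ep as mooncake_ep  # optional dependency, guarded above

from sglang.srt.utils import get_local_ip_auto

mooncake_ep.set_host_ip(get_local_ip_auto())
# ...followed by torch.distributed.init_process_group(backend="mooncake", ...)
```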
@@ -59,6 +59,7 @@ logger = logging.getLogger(__name__)
class DeepEPMoE(FusedMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    Mooncake EP shares the same class, as they expose the same interface.
    """

    _has_printed = False
@@ -686,7 +687,7 @@ class DeepEPMoE(FusedMoE):


def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
-    if get_moe_a2a_backend().is_deepep():
+    if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
        return DeepEPMoE

    # NEW: Direct FP4 detection (bypasses EP requirements)
@@ -16,6 +16,11 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
    DeepEPNormalCombineInput,
    DeepEPNormalOutput,
)
+from sglang.srt.layers.moe.token_dispatcher.mooncake import (
+    MooncakeCombineInput,
+    MooncakeDispatchOutput,
+    MooncakeEPDispatcher,
+)
from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
    StandardDispatchOutput,
@@ -30,6 +35,9 @@ __all__ = [
    "DispatchOutput",
    "DispatchOutputFormat",
    "DispatchOutputChecker",
+    "MooncakeCombineInput",
+    "MooncakeDispatchOutput",
+    "MooncakeEPDispatcher",
    "StandardDispatchOutput",
    "StandardCombineInput",
    "DeepEPConfig",
python/sglang/srt/layers/moe/token_dispatcher/mooncake.py (new file, 394 lines)
@@ -0,0 +1,394 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
    BaseDispatcher,
    CombineInput,
    CombineInputFormat,
    DispatchOutput,
    DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.utils import get_int_env_var

try:
    from mooncake.mooncake_ep_buffer import Buffer

    use_mooncake_ep = True
except ImportError:
    use_mooncake_ep = False

from enum import Enum, IntEnum, auto

import torch
import torch.distributed as dist

from sglang.srt.model_executor.forward_batch_info import ForwardBatch

logger = logging.getLogger(__name__)


class MooncakeDispatchOutput(NamedTuple):
    """Mooncake EP dispatch output."""

    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.DEEPEP_LL


assert isinstance(MooncakeDispatchOutput, DispatchOutput)


class MooncakeCombineInput(NamedTuple):
    """Mooncake EP combine input."""

    pass

    @property
    def format(self) -> CombineInputFormat:
        return CombineInputFormat.DEEPEP_LL


assert isinstance(MooncakeCombineInput, CombineInput)


_ACTIVE_RANKS: Optional[torch.Tensor] = None


def get_ep_active_ranks() -> torch.Tensor:
    assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
    return _ACTIVE_RANKS


class EPBuffer:
    _buffer = None
    _hidden_size: Optional[int] = None
    _num_max_dispatch_tokens_per_rank: Optional[int] = None
    _num_experts: Optional[int] = None

    @classmethod
    def get_ep_buffer(
        cls,
        group: dist.ProcessGroup,
        hidden_size: int,
        param_bytes: int,
        deepep_mode: DeepEPMode,
        num_max_dispatch_tokens_per_rank: int = -1,
        num_experts: int = -1,
    ):
        if cls._buffer is not None:
            return cls._buffer

        cls._hidden_size = hidden_size
        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        cls._num_experts = num_experts

        num_ep_buffer_bytes = 0
        if deepep_mode.enable_normal():
            raise NotImplementedError(
                "Normal mode is not supported for Mooncake EP yet."
            )
        if deepep_mode.enable_low_latency():
            assert num_max_dispatch_tokens_per_rank != -1
            assert num_experts != -1 and num_experts % group.size() == 0
            num_ep_buffer_bytes = Buffer.get_ep_buffer_size_hint(
                num_max_dispatch_tokens_per_rank,
                hidden_size,
                group.size(),
                num_experts,
            )

        cls._buffer = Buffer(group, num_ep_buffer_bytes)
        return cls._buffer


class _MooncakeEPDispatcherImpl:
    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool,
        num_experts: int,
        num_local_experts: int,
        hidden_size: int,
        params_dtype: torch.dtype,
        return_recv_hook: bool,
        deepep_mode: DeepEPMode,
    ):
        if not use_mooncake_ep:
            raise ImportError(
                "Mooncake EP is not installed. Please install Mooncake package at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
                "with EP support to run SGLang with Mooncake EP."
            )
        self.group = group
        self.router_topk = router_topk
        self.permute_fusion = permute_fusion
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.params_dtype = params_dtype
        self.return_recv_hook = return_recv_hook
        self.deepep_mode = deepep_mode

        self.params_bytes = 2
        self.num_max_dispatch_tokens_per_rank = get_int_env_var(
            "SGLANG_MOONCAKE_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
        )
        # Mooncake EP dispatch uses FINISHED_SUM_TAG=1024, and its logic requires
        # the number of tokens sent from one rank to another to stay below it.
        assert self.num_max_dispatch_tokens_per_rank <= 1024

        self.first_execution = True
        self.timeout_us = 10000000

        global _ACTIVE_RANKS
        if _ACTIVE_RANKS is None:
            _ACTIVE_RANKS = torch.ones(
                (self.num_experts,), dtype=torch.int32, device="cuda"
            )
        self.active_ranks = _ACTIVE_RANKS

        self.handle = None

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        buffer = self._get_buffer()
        topk_idx = topk_idx.to(torch.int64)
        expected_m = (
            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
            + self.num_experts
        ) // self.num_experts
        hidden_states, masked_m, event, hook = self._dispatch_core(
            hidden_states,
            topk_idx,
            use_fp8=True,
        )
        return (
            hidden_states,
            topk_idx,
            topk_weights,
            masked_m,
            expected_m,
            event,
            hook,
        )

    def dispatch_b(
        self,
        hidden_states,
        topk_idx,
        topk_weights,
        masked_m,
        expected_m,
        event,
        hook,
    ):
        hook() if self.return_recv_hook else event.current_stream_wait()

        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
            masked_m
        )

        return MooncakeDispatchOutput(
            hidden_states,
            topk_idx,
            topk_weights,
            masked_m,
            expected_m,
        )

    def _dispatch_core(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        use_fp8: bool = False,
    ):
        buffer = self._get_buffer()
        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
            buffer.dispatch(
                hidden_states,
                topk_idx,
                self.active_ranks,
                self.num_max_dispatch_tokens_per_rank,
                self.num_experts,
                -1 if self.first_execution else self.timeout_us,
                use_fp8=use_fp8,
                async_finish=not self.return_recv_hook,
                return_recv_hook=self.return_recv_hook,
            )
        )
        return packed_recv_hidden, packed_recv_count, event, hook

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        hidden_states, event, hook = self._combine_core(
            hidden_states,
            topk_idx,
            topk_weights,
        )
        return hidden_states, event, hook

    def combine_b(self, hidden_states, event, hook):
        hook() if self.return_recv_hook else event.current_stream_wait()
        return hidden_states

    def _combine_core(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
        buffer = self._get_buffer()
        combined_hidden_states, event, hook = buffer.combine(
            hidden_states,
            topk_idx,
            topk_weights,
            self.active_ranks,
            -1 if self.first_execution else self.timeout_us,
            self.handle,
            async_finish=not self.return_recv_hook,
            return_recv_hook=self.return_recv_hook,
        )
        self.first_execution = False
        self.handle = None
        return combined_hidden_states, event, hook

    def _get_buffer(self):
        return EPBuffer.get_ep_buffer(
            self.group,
            self.hidden_size,
            self.params_bytes,
            self.deepep_mode,
            self.num_max_dispatch_tokens_per_rank,
            self.num_experts,
        )


@dataclass
class _Stage(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()


class MooncakeEPDispatcher(BaseDispatcher):
    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool = False,
        num_experts: int = None,
        num_local_experts: int = None,
        hidden_size: int = None,
        params_dtype: torch.dtype = None,
        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
        async_finish: bool = False,
        return_recv_hook: bool = False,
    ):
        self.deepep_mode = deepep_mode

        if self.deepep_mode.enable_low_latency():
            self._low_latency_dispatcher = _MooncakeEPDispatcherImpl(
                group=group,
                router_topk=router_topk,
                permute_fusion=permute_fusion,
                num_experts=num_experts,
                num_local_experts=num_local_experts,
                hidden_size=hidden_size,
                params_dtype=params_dtype,
                return_recv_hook=return_recv_hook,
                deepep_mode=deepep_mode,
            )
        if self.deepep_mode.enable_normal():
            raise NotImplementedError

        self._stage = _Stage.INITIAL

    def dispatch(self, *args, **kwargs) -> DispatchOutput:
        self.dispatch_a(*args, **kwargs)
        ret = self.dispatch_b()
        return ret

    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
        input_global_scale: Optional[torch.Tensor],
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
        inner_state = self._get_impl(forward_batch).dispatch_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
        )
        self._dispatch_intermediate_state = forward_batch, inner_state

    def dispatch_b(self):
        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
        forward_batch, inner_state = self._dispatch_intermediate_state
        del self._dispatch_intermediate_state
        return self._get_impl(forward_batch).dispatch_b(*inner_state)

    def combine(self, *args, **kwargs) -> Tuple:
        self.combine_a(*args, **kwargs)
        ret = self.combine_b()
        return ret

    def combine_a(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        forward_batch: ForwardBatch,
        overlap_args: Optional = None,
    ):
        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
        inner_state = self._get_impl(forward_batch).combine_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
        )
        self._combine_intermediate_state = forward_batch, inner_state

    def combine_b(self):
        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
        forward_batch, inner_state = self._combine_intermediate_state
        del self._combine_intermediate_state
        return self._get_impl(forward_batch).combine_b(*inner_state)

    def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl:
        resolved_deepep_mode = self.deepep_mode.resolve(
            forward_batch.is_extend_in_batch
        )
        if resolved_deepep_mode == DeepEPMode.NORMAL:
            raise NotImplementedError
        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
            return self._low_latency_dispatcher
        else:
            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

    def _update_stage(self, old_stage, new_stage):
        assert self._stage == old_stage
        self._stage = new_stage
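The `_a`/`_b` split lets callers launch the all-to-all, run other work, and block only when the result is needed. A hedged usage sketch; the group setup, tensor shapes, and the `run_local_experts` helper are illustrative assumptions, not part of this commit:

```python
import torch

# Assumes torch.distributed is initialized, `group` and `forward_batch` exist,
# and forward_batch resolves DeepEPMode to LOW_LATENCY.
dispatcher = MooncakeEPDispatcher(
    group=group,
    router_topk=8,
    num_experts=256,
    num_local_experts=256 // group.size(),
    hidden_size=7168,
    params_dtype=torch.bfloat16,
    deepep_mode=DeepEPMode.LOW_LATENCY,
    return_recv_hook=True,
)

dispatcher.dispatch_a(hidden_states, None, topk_idx, topk_weights, forward_batch)
# ...overlapped work can run here before dispatch_b() waits on the recv hook...
out = dispatcher.dispatch_b()        # MooncakeDispatchOutput
expert_out = run_local_experts(out)  # hypothetical per-rank MoE computation
dispatcher.combine_a(expert_out, out.topk_idx, out.topk_weights, forward_batch)
combined = dispatcher.combine_b()
```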
@@ -24,6 +24,7 @@ class MoeA2ABackend(Enum):
    NONE = "none"
    DEEPEP = "deepep"
+    MOONCAKE = "mooncake"

    @classmethod
    def _missing_(cls, value):
@@ -40,6 +41,9 @@ class MoeA2ABackend(Enum):
    def is_deepep(self):
        return self == MoeA2ABackend.DEEPEP

+    def is_mooncake(self):
+        return self == MoeA2ABackend.MOONCAKE
+

class MoeRunnerBackend(Enum):
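A quick illustration of how the new enum member is consumed elsewhere in this commit; the import path is an assumption based on where `DeepEPMode` lives:

```python
from sglang.srt.layers.moe.utils import MoeA2ABackend  # path assumed

backend = MoeA2ABackend("mooncake")
assert backend.is_mooncake() and not backend.is_deepep()
```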
@@ -677,7 +677,18 @@ class ModelRunner:
            raise

        if self.device == "cuda":
-            backend = "nccl"
+            if self.server_args.elastic_ep_backend == "mooncake":
+                backend = "mooncake"
+                if self.server_args.mooncake_ib_device:
+                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
+                    try:
+                        from mooncake import ep as mooncake_ep
+
+                        mooncake_ep.set_device_filter(mooncake_ib_device)
+                    except:
+                        pass  # A warning will be raised in `init_distributed_environment`
+            else:
+                backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
@@ -885,17 +896,23 @@ class ModelRunner:
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

-        # Handle the case where some ranks do not finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
+        if self.server_args.elastic_ep_backend == "mooncake":
+            # Mooncake does not support `monitored_barrier`
+            dist.barrier(group=get_tp_group().cpu_group)
+        else:
+            # Handle the case where some ranks do not finish loading.
+            try:
+                dist.monitored_barrier(
+                    group=get_tp_group().cpu_group,
+                    timeout=datetime.timedelta(
+                        seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
+                    ),
+                    wait_all_ranks=True,
+                )
+            except RuntimeError:
+                raise ValueError(
+                    f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+                ) from None

    def update_expert_location(
        self,
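`monitored_barrier` is only implemented for the `gloo` backend, which is why the Mooncake path above falls back to a plain `barrier`. A minimal sketch of the difference; the timeout value is illustrative:

```python
import datetime

import torch.distributed as dist

try:
    # gloo-only: reports which rank failed to arrive within the timeout.
    dist.monitored_barrier(
        timeout=datetime.timedelta(seconds=300), wait_all_ranks=True
    )
except RuntimeError as e:
    print(f"A rank did not reach the barrier: {e}")

dist.barrier()  # portable fallback: blocks, but cannot name the missing rank
```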
@@ -592,6 +592,7 @@ class DeepseekV2MoE(nn.Module):
            **(
                dict(tp_rank=0, tp_size=1)
                if get_moe_a2a_backend().is_deepep()
+                or get_moe_a2a_backend().is_mooncake()
                or should_use_flashinfer_cutlass_moe_fp4_allgather()
                else {}
            ),
@@ -622,7 +623,7 @@ class DeepseekV2MoE(nn.Module):

        self.top_k = config.num_experts_per_tok

-        if get_moe_a2a_backend().is_deepep():
+        if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
            # TODO: we will support tp < ep in the future
            self.ep_size = get_moe_expert_parallel_world_size()
            self.num_experts = (
@@ -651,7 +652,9 @@ class DeepseekV2MoE(nn.Module):
            return_recv_hook=True,
        )

-        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        self._enable_a2a_moe = (
+            get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
+        )

    def get_moe_weights(self):
        return [
@@ -668,7 +671,7 @@ class DeepseekV2MoE(nn.Module):
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
-        if not self._enable_deepep_moe:
+        if not self._enable_a2a_moe:
            DUAL_STREAM_TOKEN_THRESHOLD = 1024
            if (
                self.alt_stream is not None
@@ -228,6 +228,8 @@ class ServerArgs:

    # Runtime options
    device: Optional[str] = None
+    elastic_ep_backend: Literal[None, "mooncake"] = None
+    mooncake_ib_device: Optional[str] = None
    tp_size: int = 1
    pp_size: int = 1
    pp_max_micro_batch_size: Optional[int] = None
@@ -344,7 +346,7 @@ class ServerArgs:

    # Expert parallelism
    ep_size: int = 1
-    moe_a2a_backend: Literal["none", "deepep"] = "none"
+    moe_a2a_backend: Literal["none", "deepep", "mooncake"] = "none"
    moe_runner_backend: str = "auto"
    flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
    enable_flashinfer_allreduce_fusion: bool = False
@@ -537,7 +539,7 @@ class ServerArgs:

        # Handle MoE configurations.
        self._handle_moe_kernel_config()
-        self._handle_deepep_moe()
+        self._handle_a2a_moe()
        self._handle_eplb_and_dispatch()
        self._handle_expert_distribution_metrics()
@@ -1091,7 +1093,7 @@ class ServerArgs:
                "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
            )

-    def _handle_deepep_moe(self):
+    def _handle_a2a_moe(self):
        if self.moe_a2a_backend == "deepep":
            if self.deepep_mode == "normal":
                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
@@ -1101,6 +1103,12 @@ class ServerArgs:
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )

+        if self.moe_a2a_backend == "mooncake":
+            self.ep_size = self.tp_size
+            logger.warning(
+                f"Mooncake MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
    def _handle_eplb_and_dispatch(self):
        if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
            self.expert_distribution_recorder_mode = "stat"
@@ -1712,6 +1720,21 @@ class ServerArgs:
            default=ServerArgs.device,
            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
        )
+        parser.add_argument(
+            "--elastic-ep-backend",
+            type=str,
+            default=ServerArgs.elastic_ep_backend,
+            choices=["none", "mooncake"],
+            help="Specify the collective communication backend for elastic EP. Currently supports 'mooncake'.",
+        )
+        parser.add_argument(
+            "--mooncake-ib-device",
+            type=str,
+            default=ServerArgs.mooncake_ib_device,
+            help="The InfiniBand devices for Mooncake Backend transfer, accepts multiple comma-separated devices "
+            "(e.g., --mooncake-ib-device mlx5_0,mlx5_1). "
+            "Default is None, which triggers automatic device detection when Mooncake Backend is enabled.",
+        )
        parser.add_argument(
            "--tensor-parallel-size",
            "--tp-size",
@@ -2333,7 +2356,7 @@ class ServerArgs:
        parser.add_argument(
            "--moe-a2a-backend",
            type=str,
-            choices=["none", "deepep"],
+            choices=["none", "deepep", "mooncake"],
            default=ServerArgs.moe_a2a_backend,
            help="Choose the backend for MoE A2A.",
        )
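Putting the new flags together, a hedged configuration sketch; the model path and sizes are placeholders, since this commit defines the flags but no recommended values:

```python
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="path/to/model",      # placeholder
    tp_size=8,
    elastic_ep_backend="mooncake",   # collective backend for elastic EP
    moe_a2a_backend="mooncake",      # routes MoE all-to-all through Mooncake EP
    mooncake_ib_device="mlx5_0,mlx5_1",  # optional; auto-detected when unset
)
```

The equivalent CLI flags are `--elastic-ep-backend mooncake --moe-a2a-backend mooncake --mooncake-ib-device mlx5_0,mlx5_1`.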
@@ -20,7 +20,10 @@ from sglang.srt.layers.moe import (
    get_tbo_token_distribution_threshold,
    is_tbo_enabled,
)
-from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
+from sglang.srt.layers.moe.token_dispatcher import (
+    DeepEPDispatcher,
+    MooncakeEPDispatcher,
+)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import (
@@ -363,7 +366,7 @@ class TboDPAttentionPreparer:
    ):

        deepep_mode = get_deepep_mode()
-        enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        enable_a2a_moe = not get_moe_a2a_backend().is_none()
        enable_two_batch_overlap = is_tbo_enabled()

        self.enable_two_batch_overlap = enable_two_batch_overlap
@@ -392,7 +395,7 @@ class TboDPAttentionPreparer:
                    local_batch.forward_mode.is_extend()
                    and not local_batch.forward_mode.is_target_verify()
                )
-                and enable_deepep_moe
+                and enable_a2a_moe
                and (resolved_deepep_mode.is_low_latency())
            )
        else:
@@ -968,9 +971,14 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
class MaybeTboDeepEPDispatcher:
    def __init__(self, **kwargs):
        num_inner_dispatchers = 2 if is_tbo_enabled() else 1
-        self._inners = [
-            DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
-        ]
+        if get_moe_a2a_backend().is_deepep():
+            self._inners = [
+                DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
+            ]
+        elif get_moe_a2a_backend().is_mooncake():
+            self._inners = [
+                MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
+            ]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)