diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index 3673ba4d8..aad5bbc05 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -134,6 +134,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Arguments | Description | Defaults |
 |-----------|-------------|----------|
 | `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None |
+| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Currently supports 'mooncake'. | None |
+| `--mooncake-ib-device` | The InfiniBand devices for the Mooncake backend; accepts multiple comma-separated devices. Defaults to None, which triggers automatic device detection when the Mooncake backend is enabled. | None |
 | `--tp-size` | The tensor parallelism size. | 1 |
 | `--pp-size` | The pipeline parallelism size. | 1 |
 | `--pp-max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | None |
@@ -246,7 +248,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Arguments | Description | Defaults |
 |-----------|-------------|----------|
 | `--ep-size` | The expert parallelism size. | 1 |
-| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
+| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism; can be `deepep` or `mooncake`. | none |
 | `--moe-runner-backend` | Select the runner backend for MoE. | auto |
 | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
 | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
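For reference, a launch command exercising the new flags might look like the sketch below. This is illustrative and not part of the patch: the model path and IB device names are placeholders, and the combination of `--elastic-ep-backend mooncake` with `--moe-a2a-backend deepep` mirrors what the tests added later in this diff use.

```python
# Hypothetical launch example; adjust the model path and IB devices to your cluster.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
        "--trust-remote-code",
        "--tp", "4",
        "--elastic-ep-backend", "mooncake",
        "--mooncake-ib-device", "mlx5_0,mlx5_1",  # placeholder devices
        "--moe-a2a-backend", "deepep",
        "--deepep-mode", "low_latency",
    ],
    check=True,
)
```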
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 9a7ddd825..7e18d06db 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -43,6 +43,7 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_int_env_var,
+    get_local_ip_auto,
     is_cpu,
     is_cuda_alike,
     is_hip,
@@ -258,11 +259,14 @@ class GroupCoordinator:
         device_group = torch.distributed.new_group(
             ranks, backend=torch_distributed_backend
         )
-        # a group with `gloo` backend, to allow direct coordination between
-        # processes through the CPU.
-        cpu_group = torch.distributed.new_group(
-            ranks, backend="gloo", timeout=gloo_timeout
-        )
+        # a cpu_group to allow direct coordination between processes through
+        # the CPU. The backend is chosen based on `torch_distributed_backend`.
+        if "mooncake" in torch_distributed_backend:
+            cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu")
+        else:
+            cpu_group = torch.distributed.new_group(
+                ranks, backend="gloo", timeout=gloo_timeout
+            )
         if self.rank in ranks:
             self.ranks = ranks
             self.world_size = len(ranks)
@@ -1410,6 +1414,17 @@ def init_distributed_environment(
         distributed_init_method,
         backend,
     )
+    if "mooncake" in backend:
+        try:
+            from mooncake import ep as mooncake_ep
+        except ImportError as e:
+            raise ImportError(
+                "Please install mooncake by following the instructions at "
+                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
+                "to run SGLang with Mooncake Backend."
+            ) from e
+        mooncake_ep.set_host_ip(get_local_ip_auto())
+
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
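The `GroupCoordinator` change keys the CPU coordination group off the device backend string. A minimal standalone sketch of that selection, assuming `torch.distributed` is already initialized and the `mooncake`/`mooncake-cpu` backends have been registered (e.g., by importing the Mooncake package):

```python
import torch.distributed as dist

def make_groups(ranks, torch_distributed_backend: str):
    # Device group uses the requested backend ("nccl" or "mooncake").
    device_group = dist.new_group(ranks, backend=torch_distributed_backend)
    # CPU group: Mooncake ships its own CPU backend; otherwise fall back to gloo.
    if "mooncake" in torch_distributed_backend:
        cpu_group = dist.new_group(ranks, backend="mooncake-cpu")
    else:
        cpu_group = dist.new_group(ranks, backend="gloo")
    return device_group, cpu_group
```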
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index bc7251989..3f68ad563 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -59,6 +59,7 @@ logger = logging.getLogger(__name__)
 class DeepEPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    Mooncake EP shares this class, as both backends expose the same interface.
     """
 
     _has_printed = False
@@ -686,7 +687,7 @@ class DeepEPMoE(FusedMoE):
 
 
 def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
-    if get_moe_a2a_backend().is_deepep():
+    if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
         return DeepEPMoE
 
     # NEW: Direct FP4 detection (bypasses EP requirements)
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py
index e1dbcdd44..7526f73de 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -16,6 +16,11 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
     DeepEPNormalCombineInput,
     DeepEPNormalOutput,
 )
+from sglang.srt.layers.moe.token_dispatcher.mooncake import (
+    MooncakeCombineInput,
+    MooncakeDispatchOutput,
+    MooncakeEPDispatcher,
+)
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardCombineInput,
     StandardDispatchOutput,
@@ -30,6 +35,9 @@ __all__ = [
     "DispatchOutput",
     "DispatchOutputFormat",
     "DispatchOutputChecker",
+    "MooncakeCombineInput",
+    "MooncakeDispatchOutput",
+    "MooncakeEPDispatcher",
     "StandardDispatchOutput",
     "StandardCombineInput",
     "DeepEPConfig",
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
new file mode 100644
index 000000000..d6d561865
--- /dev/null
+++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
@@ -0,0 +1,394 @@
+from __future__ import annotations
+
+import logging
+from typing import NamedTuple, Optional, Tuple
+
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.moe.token_dispatcher.base import (
+    BaseDispatcher,
+    CombineInput,
+    CombineInputFormat,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.utils import get_int_env_var
+
+try:
+    from mooncake.mooncake_ep_buffer import Buffer
+
+    use_mooncake_ep = True
+except ImportError:
+    use_mooncake_ep = False
+
+from enum import Enum, auto
+
+import torch
+import torch.distributed as dist
+
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+logger = logging.getLogger(__name__)
+
+
+class MooncakeDispatchOutput(NamedTuple):
+    """Mooncake EP dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.DEEPEP_LL
+
+
+assert isinstance(MooncakeDispatchOutput, DispatchOutput)
+
+
+class MooncakeCombineInput(NamedTuple):
+    """Mooncake EP combine input."""
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_LL
+
+
+assert isinstance(MooncakeCombineInput, CombineInput)
+
+
+_ACTIVE_RANKS: Optional[torch.Tensor] = None
+
+
+def get_ep_active_ranks() -> torch.Tensor:
+    assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
+    return _ACTIVE_RANKS
+
+
+class EPBuffer:
+    _buffer = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
+
+    @classmethod
+    def get_ep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = -1,
+        num_experts: int = -1,
+    ):
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_ep_buffer_bytes = 0
+        if deepep_mode.enable_normal():
+            raise NotImplementedError(
+                "Normal mode is not supported for Mooncake EP yet."
+            )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank != -1
+            assert num_experts != -1 and num_experts % group.size() == 0
+            num_ep_buffer_bytes = Buffer.get_ep_buffer_size_hint(
+                num_max_dispatch_tokens_per_rank,
+                hidden_size,
+                group.size(),
+                num_experts,
+            )
+
+        cls._buffer = Buffer(group, num_ep_buffer_bytes)
+        return cls._buffer
+
+
+class _MooncakeEPDispatcherImpl:
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool,
+        num_experts: int,
+        num_local_experts: int,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+        return_recv_hook: bool,
+        deepep_mode: DeepEPMode,
+    ):
+        if not use_mooncake_ep:
+            raise ImportError(
+                "Mooncake EP is not installed. Please install the Mooncake package at "
+                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
+                "with EP support to run SGLang with Mooncake EP."
+            )
+        self.group = group
+        self.router_topk = router_topk
+        self.permute_fusion = permute_fusion
+        self.num_experts = num_experts
+        self.num_local_experts = num_local_experts
+        self.hidden_size = hidden_size
+        self.params_dtype = params_dtype
+        self.return_recv_hook = return_recv_hook
+        self.deepep_mode = deepep_mode
+
+        self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = get_int_env_var(
+            "SGLANG_MOONCAKE_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
+        )
+        # Mooncake EP dispatch uses FINISHED_SUM_TAG=1024, and its logic requires
+        # the number of tokens sent from one rank to another to stay below that.
+        assert self.num_max_dispatch_tokens_per_rank <= 1024
+
+        self.first_execution = True
+        self.timeout_us = 10000000
+
+        global _ACTIVE_RANKS
+        if _ACTIVE_RANKS is None:
+            _ACTIVE_RANKS = torch.ones(
+                (self.num_experts,), dtype=torch.int32, device="cuda"
+            )
+        self.active_ranks = _ACTIVE_RANKS
+
+        self.handle = None
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        buffer = self._get_buffer()
+        topk_idx = topk_idx.to(torch.int64)
+        expected_m = (
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
+        hidden_states, masked_m, event, hook = self._dispatch_core(
+            hidden_states,
+            topk_idx,
+            use_fp8=True,
+        )
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+            event,
+            hook,
+        )
+
+    def dispatch_b(
+        self,
+        hidden_states,
+        topk_idx,
+        topk_weights,
+        masked_m,
+        expected_m,
+        event,
+        hook,
+    ):
+        hook() if self.return_recv_hook else event.current_stream_wait()
+
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
+
+        return MooncakeDispatchOutput(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_core(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        use_fp8: bool = False,
+    ):
+        buffer = self._get_buffer()
+        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
+            buffer.dispatch(
+                hidden_states,
+                topk_idx,
+                self.active_ranks,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
+                -1 if self.first_execution else self.timeout_us,
+                use_fp8=use_fp8,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
+            )
+        )
+        return packed_recv_hidden, packed_recv_count, event, hook
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        hidden_states, event, hook = self._combine_core(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+        )
+        return hidden_states, event, hook
+
+    def combine_b(self, hidden_states, event, hook):
+        hook() if self.return_recv_hook else event.current_stream_wait()
+        return hidden_states
+
+    def _combine_core(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.combine(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.active_ranks,
+            -1 if self.first_execution else self.timeout_us,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
+        )
+        self.first_execution = False
+        self.handle = None
+        return combined_hidden_states, event, hook
+
+    def _get_buffer(self):
+        return EPBuffer.get_ep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
+
+class _Stage(Enum):
+    INITIAL = auto()
+    AFTER_DISPATCH_A = auto()
+    AFTER_DISPATCH_B = auto()
+    AFTER_COMBINE_A = auto()
+
+
+class MooncakeEPDispatcher(BaseDispatcher):
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        num_experts: Optional[int] = None,
+        num_local_experts: Optional[int] = None,
+        hidden_size: Optional[int] = None,
+        params_dtype: Optional[torch.dtype] = None,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
+        async_finish: bool = False,
+        return_recv_hook: bool = False,
+    ):
+        self.deepep_mode = deepep_mode
+
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher = _MooncakeEPDispatcherImpl(
+                group=group,
+                router_topk=router_topk,
+                permute_fusion=permute_fusion,
+                num_experts=num_experts,
+                num_local_experts=num_local_experts,
+                hidden_size=hidden_size,
+                params_dtype=params_dtype,
+                return_recv_hook=return_recv_hook,
+                deepep_mode=deepep_mode,
+            )
+        if self.deepep_mode.enable_normal():
+            raise NotImplementedError(
+                "Normal mode is not supported for Mooncake EP yet."
+            )
+
+        self._stage = _Stage.INITIAL
+
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
+        self.dispatch_a(*args, **kwargs)
+        ret = self.dispatch_b()
+        return ret
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        input_global_scale: Optional[torch.Tensor],
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
+        inner_state = self._get_impl(forward_batch).dispatch_a(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+        )
+        self._dispatch_intermediate_state = forward_batch, inner_state
+
+    def dispatch_b(self):
+        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
+        forward_batch, inner_state = self._dispatch_intermediate_state
+        del self._dispatch_intermediate_state
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)
+
+    def combine(self, *args, **kwargs) -> Tuple:
+        self.combine_a(*args, **kwargs)
+        ret = self.combine_b()
+        return ret
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+        overlap_args: Optional[object] = None,
+    ):
+        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
+        inner_state = self._get_impl(forward_batch).combine_a(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+        )
+        self._combine_intermediate_state = forward_batch, inner_state
+
+    def combine_b(self):
+        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
+        forward_batch, inner_state = self._combine_intermediate_state
+        del self._combine_intermediate_state
+        return self._get_impl(forward_batch).combine_b(*inner_state)
+
+    def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
+            raise NotImplementedError
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
+            return self._low_latency_dispatcher
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+    def _update_stage(self, old_stage, new_stage):
+        assert self._stage == old_stage
+        self._stage = new_stage
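`MooncakeEPDispatcher` enforces the same split-phase protocol as the DeepEP dispatcher via the `_Stage` state machine: `dispatch_a` posts the all-to-all, `dispatch_b` waits for it, expert computation runs in between, then `combine_a`/`combine_b` gather the results back. A hedged usage sketch of that call order (`run_experts` stands in for the MoE expert kernels; argument names follow the class signatures above):

```python
import torch

from sglang.srt.layers.moe.token_dispatcher import MooncakeEPDispatcher

def moe_forward_with_mooncake(
    dispatcher: MooncakeEPDispatcher,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    forward_batch,          # sglang ForwardBatch
    run_experts,            # stand-in for the local expert computation
) -> torch.Tensor:
    # Phase A: post the dispatch, then wait for tokens to arrive.
    dispatcher.dispatch_a(
        hidden_states=hidden_states,
        input_global_scale=None,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_batch=forward_batch,
    )
    dispatch_output = dispatcher.dispatch_b()

    # Run the local experts on the received tokens.
    expert_output = run_experts(dispatch_output)

    # Phase B: post the combine, then wait for the reduced result.
    dispatcher.combine_a(
        hidden_states=expert_output,
        topk_idx=dispatch_output.topk_idx,
        topk_weights=dispatch_output.topk_weights,
        forward_batch=forward_batch,
    )
    return dispatcher.combine_b()
```

The a/b split is what lets callers such as the two-batch-overlap path (modified later in this diff) interleave communication of one micro-batch with computation of the other; the single-shot `dispatch()`/`combine()` wrappers simply run both halves back to back.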
+ MOONCAKE = "mooncake" @classmethod def _missing_(cls, value): @@ -40,6 +41,9 @@ class MoeA2ABackend(Enum): def is_deepep(self): return self == MoeA2ABackend.DEEPEP + def is_mooncake(self): + return self == MoeA2ABackend.MOONCAKE + class MoeRunnerBackend(Enum): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index aceb572c9..b4eab1d5a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -677,7 +677,18 @@ class ModelRunner: raise if self.device == "cuda": - backend = "nccl" + if self.server_args.elastic_ep_backend == "mooncake": + backend = "mooncake" + if self.server_args.mooncake_ib_device: + mooncake_ib_device = self.server_args.mooncake_ib_device.split(",") + try: + from mooncake import ep as mooncake_ep + + mooncake_ep.set_device_filter(mooncake_ib_device) + except: + pass # A warning will be raised in `init_distributed_environment` + else: + backend = "nccl" elif self.device == "xpu": backend = "xccl" elif self.device == "hpu": @@ -885,17 +896,23 @@ class ModelRunner: f"mem usage={self.weight_load_mem_usage:.2f} GB." ) - # Handle the case where some ranks do not finish loading. - try: - dist.monitored_barrier( - group=get_tp_group().cpu_group, - timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S), - wait_all_ranks=True, - ) - except RuntimeError: - raise ValueError( - f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." - ) from None + if self.server_args.elastic_ep_backend == "mooncake": + # Mooncake does not support `monitored_barrier` + dist.barrier(group=get_tp_group().cpu_group) + else: + # Handle the case where some ranks do not finish loading. + try: + dist.monitored_barrier( + group=get_tp_group().cpu_group, + timeout=datetime.timedelta( + seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S + ), + wait_all_ranks=True, + ) + except RuntimeError: + raise ValueError( + f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index aceb572c9..b4eab1d5a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -677,7 +677,18 @@ class ModelRunner:
                 raise
 
         if self.device == "cuda":
-            backend = "nccl"
+            if self.server_args.elastic_ep_backend == "mooncake":
+                backend = "mooncake"
+                if self.server_args.mooncake_ib_device:
+                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
+                    try:
+                        from mooncake import ep as mooncake_ep
+
+                        mooncake_ep.set_device_filter(mooncake_ib_device)
+                    except ImportError:
+                        pass  # An ImportError with install instructions is raised in `init_distributed_environment`
+            else:
+                backend = "nccl"
         elif self.device == "xpu":
             backend = "xccl"
         elif self.device == "hpu":
@@ -885,17 +896,23 @@ class ModelRunner:
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
 
-        # Handle the case where some ranks do not finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
+        if self.server_args.elastic_ep_backend == "mooncake":
+            # Mooncake does not support `monitored_barrier`.
+            dist.barrier(group=get_tp_group().cpu_group)
+        else:
+            # Handle the case where some ranks do not finish loading.
+            try:
+                dist.monitored_barrier(
+                    group=get_tp_group().cpu_group,
+                    timeout=datetime.timedelta(
+                        seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
+                    ),
+                    wait_all_ranks=True,
+                )
+            except RuntimeError:
+                raise ValueError(
+                    f"TP rank {self.tp_rank} finished the model loading, but some other ranks did not. This is likely due to unexpected failures (e.g., OOM) or a slow node."
+                ) from None
 
     def update_expert_location(
         self,
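Because Mooncake's CPU process group does not implement `monitored_barrier`, the post-load synchronization falls back to a plain `barrier`. The trade-off, sketched below with an illustrative helper (not part of the patch), is that a plain barrier cannot attribute a hang to a specific slow or failed rank:

```python
import datetime

import torch.distributed as dist

def post_load_barrier(cpu_group, elastic_ep_backend: str, timeout_s: int) -> None:
    if elastic_ep_backend == "mooncake":
        # Mooncake lacks `monitored_barrier`; a plain barrier synchronizes
        # but cannot report which rank failed to arrive.
        dist.barrier(group=cpu_group)
    else:
        # Raises RuntimeError identifying missing ranks if the timeout expires.
        dist.monitored_barrier(
            group=cpu_group,
            timeout=datetime.timedelta(seconds=timeout_s),
            wait_all_ranks=True,
        )
```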
Defaults to auto-detection if not specified.", ) + parser.add_argument( + "--elastic-ep-backend", + type=str, + default=ServerArgs.elastic_ep_backend, + choices=["none", "mooncake"], + help="Specify the collective communication backend for elastic EP. Currently supports 'mooncake'.", + ) + parser.add_argument( + "--mooncake-ib-device", + type=str, + default=ServerArgs.mooncake_ib_device, + help="The InfiniBand devices for Mooncake Backend transfer, accepts multiple comma-separated devices " + "(e.g., --mooncake-ib-device mlx5_0,mlx5_1). " + "Default is None, which triggers automatic device detection when Mooncake Backend is enabled.", + ) parser.add_argument( "--tensor-parallel-size", "--tp-size", @@ -2333,7 +2356,7 @@ class ServerArgs: parser.add_argument( "--moe-a2a-backend", type=str, - choices=["none", "deepep"], + choices=["none", "deepep", "mooncake"], default=ServerArgs.moe_a2a_backend, help="Choose the backend for MoE A2A.", ) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index a5485e8e9..b09c72dae 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -20,7 +20,10 @@ from sglang.srt.layers.moe import ( get_tbo_token_distribution_threshold, is_tbo_enabled, ) -from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher +from sglang.srt.layers.moe.token_dispatcher import ( + DeepEPDispatcher, + MooncakeEPDispatcher, +) from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( @@ -363,7 +366,7 @@ class TboDPAttentionPreparer: ): deepep_mode = get_deepep_mode() - enable_deepep_moe = get_moe_a2a_backend().is_deepep() + enable_a2a_moe = not get_moe_a2a_backend().is_none() enable_two_batch_overlap = is_tbo_enabled() self.enable_two_batch_overlap = enable_two_batch_overlap @@ -392,7 +395,7 @@ class TboDPAttentionPreparer: local_batch.forward_mode.is_extend() and not local_batch.forward_mode.is_target_verify() ) - and enable_deepep_moe + and enable_a2a_moe and (resolved_deepep_mode.is_low_latency()) ) else: @@ -968,9 +971,14 @@ def _model_forward_tbo_merge_outputs(output_a, output_b): class MaybeTboDeepEPDispatcher: def __init__(self, **kwargs): num_inner_dispatchers = 2 if is_tbo_enabled() else 1 - self._inners = [ - DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) - ] + if get_moe_a2a_backend().is_deepep(): + self._inners = [ + DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) + ] + elif get_moe_a2a_backend().is_mooncake(): + self._inners = [ + MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) + ] def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs): return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs) diff --git a/scripts/ci/ci_install_deepep.sh b/scripts/ci/ci_install_deepep.sh index d92b7fbb3..8c1cd95f4 100755 --- a/scripts/ci/ci_install_deepep.sh +++ b/scripts/ci/ci_install_deepep.sh @@ -10,6 +10,10 @@ export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH" export PATH="${NVSHMEM_DIR}/bin:$PATH" export CUDA_HOME=/usr/local/cuda +# Install Mooncake+EP +curl -L https://cloud.tsinghua.edu.cn/f/c22ec766545e48bf99e8/?dl=1 -o mooncake_transfer_engine-0.3.6.post1+ep-cp310-cp310-manylinux_2_17_x86_64.manylinux_2_35_x86_64.whl +UV_SYSTEM_PYTHON=true uv pip install mooncake_transfer_engine-0.3.6.post1+ep-cp310-cp310-manylinux_2_17_x86_64.manylinux_2_35_x86_64.whl + if python3 -c 
"import deep_ep" >/dev/null 2>&1; then echo "deep_ep is already installed or importable. Skipping installation." exit 0 diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py new file mode 100644 index 000000000..111260a8c --- /dev/null +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -0,0 +1,286 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST_MLA, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestPureDP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", + "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "512", + "--mem-fraction-static", + "0.5", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestHybridDPTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "2", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", + "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", + "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + 
"--chunked-prefill-size", + "512", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "128", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestNoGatherdBuffer(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--enable-dp-lm-head", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", + "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--cuda-graph-max-bs", + "32", + "--max-running-requests", + "512", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestTBO(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", + "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--enable-two-batch-overlap", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "512", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 96289a3df..69ae5b9c5 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -171,6 +171,7 @@ suites = { ], "per-commit-4-gpu-deepep": [ TestFile("ep/test_deepep_small.py", 531), + TestFile("ep/test_mooncake_ep_small.py", 450), ], "per-commit-8-gpu-deepep": [ TestFile("ep/test_deepep_large.py", 338),