From fd08c0482129626b501ed80ceab5ac61f0f849c4 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sun, 18 May 2025 08:09:42 +0800 Subject: [PATCH] Support custom DeepEP tuning config (#6257) --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 40 ++++++++++++++++--- python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/server_args.py | 7 ++++ python/sglang/srt/utils.py | 7 ++++ test/srt/test_moe_deepep.py | 28 +++++++++++++ 6 files changed, 79 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 34a79f0e8..4d165dbd2 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,8 +1,11 @@ +import logging + from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM -from sglang.srt.utils import DeepEPMode +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.utils import DeepEPMode, load_json_config try: - from deep_ep import Buffer + from deep_ep import Buffer, Config from sglang.srt.layers.quantization.fp8_kernel import ( sglang_per_token_group_quant_fp8, @@ -25,6 +28,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardMode +logger = logging.getLogger(__name__) + class DeepEPDispatchMode(IntEnum): NORMAL = auto() @@ -32,7 +37,6 @@ class DeepEPDispatchMode(IntEnum): class DeepEPBuffer: - _buffer = None _dispatch_mode: Optional[DeepEPDispatchMode] = None _hidden_size: Optional[int] = None @@ -60,8 +64,10 @@ class DeepEPBuffer: if deepep_mode.enable_normal(): hidden_bytes = hidden_size * param_bytes for config in ( - Buffer.get_dispatch_config(group.size()), - Buffer.get_combine_config(group.size()), + _DeepEPConfig.get_instance().normal_dispatch_config + or Buffer.get_dispatch_config(group.size()), + _DeepEPConfig.get_instance().normal_combine_config + or Buffer.get_combine_config(group.size()), ): num_nvl_bytes = max( config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), @@ -113,6 +119,28 @@ class DeepEPBuffer: cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY +class _DeepEPConfig: + _instance = None + + def __init__(self): + config_str = global_server_args_dict["deepep_config"] + if config_str: + config_parsed = load_json_config(config_str) + if torch.distributed.get_rank() == 0: + logger.info(f"Use DeepEP Config: {config_parsed}") + self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"]) + self.normal_combine_config = Config(**config_parsed["normal_combine"]) + else: + self.normal_dispatch_config = None + self.normal_combine_config = None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = _DeepEPConfig() + return cls._instance + + class _DeepEPDispatcherImplBase: def __init__( self, @@ -295,6 +323,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): async_finish=self.async_finish, allocate_on_comm_stream=(previous_event is not None) and self.async_finish, expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1, + config=_DeepEPConfig.get_instance().normal_dispatch_config, ) return ( @@ -394,6 +423,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): async_finish=self.async_finish, previous_event=previous_event, allocate_on_comm_stream=previous_event is not None, + config=_DeepEPConfig.get_instance().normal_combine_config, ) return combined_x, event diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c0464541d..83479cd59 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -77,6 +77,7 @@ global_server_args_dict = { "enable_dp_attention": ServerArgs.enable_dp_attention, "enable_dp_lm_head": ServerArgs.enable_dp_lm_head, "enable_ep_moe": ServerArgs.enable_ep_moe, + "deepep_config": ServerArgs.deepep_config, "enable_nan_detection": ServerArgs.enable_nan_detection, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "max_micro_batch_size": ServerArgs.max_micro_batch_size, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c246fd82d..b35620b13 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -165,6 +165,7 @@ class ModelRunner: "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, "enable_deepep_moe": server_args.enable_deepep_moe, + "deepep_config": server_args.deepep_config, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "moe_dense_tp_size": server_args.moe_dense_tp_size, "n_share_experts_fusion": server_args.n_share_experts_fusion, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index fe42a7c34..62c4b990f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -169,6 +169,7 @@ class ServerArgs: enable_ep_moe: bool = False enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" + deepep_config: Optional[str] = None enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -1249,6 +1250,12 @@ class ServerArgs: default="auto", help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", ) + parser.add_argument( + "--deepep-config", + type=str, + default=ServerArgs.deepep_config, + help="Tuned DeepEP config suitable for your own cluster.", + ) parser.add_argument( "--n-share-experts-fusion", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e82408fa5..9cbea8252 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2102,5 +2102,12 @@ def log_info_on_rank0(logger, msg): logger.info(msg) +def load_json_config(data: str): + try: + return json.loads(data) + except JSONDecodeError: + return json.loads(Path(data).read_text()) + + def dispose_tensor(x: torch.Tensor): x.set_(torch.empty((0,), device=x.device, dtype=x.dtype)) diff --git a/test/srt/test_moe_deepep.py b/test/srt/test_moe_deepep.py index a25146eb5..6504d9f8f 100644 --- a/test/srt/test_moe_deepep.py +++ b/test/srt/test_moe_deepep.py @@ -1,3 +1,5 @@ +import json +import os import unittest from types import SimpleNamespace @@ -64,8 +66,34 @@ class TestDPAttn(unittest.TestCase): "2", "--enable-dp-attention", "--enable-deepep-moe", + "--deepep-mode", + "normal", "--disable-cuda-graph", + # Test custom config + "--deepep-config", + json.dumps( + { + "normal_dispatch": { + "num_sms": 20, + "num_max_nvl_chunked_send_tokens": 16, + "num_max_nvl_chunked_recv_tokens": 256, + "num_max_rdma_chunked_send_tokens": 6, + "num_max_rdma_chunked_recv_tokens": 128, + }, + "normal_combine": { + "num_sms": 20, + "num_max_nvl_chunked_send_tokens": 6, + "num_max_nvl_chunked_recv_tokens": 256, + "num_max_rdma_chunked_send_tokens": 6, + "num_max_rdma_chunked_recv_tokens": 128, + }, + } + ), ], + env={ + "SGL_ENABLE_JIT_DEEPGEMM": "0", + **os.environ, + }, ) @classmethod