From a40aecc5a3a5413bde543fe88f221225673bb605 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Wed, 21 May 2025 17:04:33 +0800 Subject: [PATCH] Fix num_qps_per_rank computation when providing custom DeepEP configuration (#6468) --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index b647f456b..fe9fbad67 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -67,9 +67,9 @@ class DeepEPBuffer: if deepep_mode.enable_normal(): hidden_bytes = hidden_size * param_bytes for config in ( - _DeepEPConfig.get_instance().normal_dispatch_config + DeepEPConfig.get_instance().normal_dispatch_config or Buffer.get_dispatch_config(group.size()), - _DeepEPConfig.get_instance().normal_combine_config + DeepEPConfig.get_instance().normal_combine_config or Buffer.get_combine_config(group.size()), ): num_nvl_bytes = max( @@ -97,7 +97,12 @@ class DeepEPBuffer: num_nvl_bytes, num_rdma_bytes, low_latency_mode=deepep_mode.enable_low_latency(), - num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)), + num_qps_per_rank=( + max( + num_experts // group.size(), + DeepEPConfig.get_instance().num_sms // 2, + ) + ), ) return cls._buffer @@ -122,7 +127,7 @@ class DeepEPBuffer: cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY -class _DeepEPConfig: +class DeepEPConfig: _instance = None def __init__(self): @@ -131,16 +136,23 @@ class _DeepEPConfig: config_parsed = load_json_config(config_str) if torch.distributed.get_rank() == 0: logger.info(f"Use DeepEP Config: {config_parsed}") - self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"]) - self.normal_combine_config = Config(**config_parsed["normal_combine"]) + config_dispatch = config_parsed["normal_dispatch"] + config_combine = config_parsed["normal_combine"] + + self.normal_dispatch_config = Config(**config_dispatch) + self.normal_combine_config = Config(**config_combine) + + assert config_dispatch["num_sms"] == config_combine["num_sms"] + self.num_sms = config_dispatch["num_sms"] else: self.normal_dispatch_config = None self.normal_combine_config = None + self.num_sms = Buffer.num_sms @classmethod def get_instance(cls): if cls._instance is None: - cls._instance = _DeepEPConfig() + cls._instance = DeepEPConfig() return cls._instance @@ -326,7 +338,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): async_finish=self.async_finish, allocate_on_comm_stream=(previous_event is not None) and self.async_finish, expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1, - config=_DeepEPConfig.get_instance().normal_dispatch_config, + config=DeepEPConfig.get_instance().normal_dispatch_config, ) get_global_expert_distribution_recorder().on_deepep_dispatch_normal( @@ -433,7 +445,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): async_finish=self.async_finish, previous_event=previous_event, allocate_on_comm_stream=previous_event is not None, - config=_DeepEPConfig.get_instance().normal_combine_config, + config=DeepEPConfig.get_instance().normal_combine_config, ) return combined_x, event