From a589a0716774196d437bdbfe282283e593f0882a Mon Sep 17 00:00:00 2001 From: Atream <80757050+Atream@users.noreply.github.com> Date: Sun, 20 Jul 2025 13:13:46 +0800 Subject: [PATCH] fix moe gate dtype, fix tbo, fix fake dispatch (#7825) --- python/sglang/srt/eplb/expert_location_dispatch.py | 2 +- python/sglang/srt/layers/moe/topk.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location_dispatch.py b/python/sglang/srt/eplb/expert_location_dispatch.py index 36224eee7..8d2160b6e 100644 --- a/python/sglang/srt/eplb/expert_location_dispatch.py +++ b/python/sglang/srt/eplb/expert_location_dispatch.py @@ -66,7 +66,7 @@ def transform_select_experts_inputs( info: Optional[ExpertLocationDispatchInfo], ): if (info is not None) and (info.ep_dispatch_algorithm == "fake"): - router_logits = torch.randn_like(router_logits) + router_logits.uniform_(5, 10) if correction_bias is not None: correction_bias = torch.zeros_like(correction_bias) return router_logits, correction_bias diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index bb3cf6515..c3ae9af25 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -499,7 +499,7 @@ def biased_grouped_topk_gpu( and is_power_of_two(correction_bias.shape[0]) ): topk_weights, topk_ids = moe_fused_gate( - gating_output, + gating_output.to(dtype=torch.float32), correction_bias, num_expert_group, topk_group, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9ec5db926..a65337945 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -229,7 +229,7 @@ class MoEGate(nn.Module): ) if config.topk_method == "noaux_tc": self.e_score_correction_bias = nn.Parameter( - torch.empty((config.n_routed_experts)) + torch.empty((config.n_routed_experts), dtype=torch.float32) ) else: self.e_score_correction_bias = None