diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 5b22f5a1f..9af3b8a2b 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -213,14 +213,6 @@ class TopK(CustomOp):
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
 
-        if (
-            quant_config is not None
-            and quant_config.get_name() == "modelopt_fp4"
-            and should_use_flashinfer_trtllm_moe()
-        ):
-            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
-            correction_bias = correction_bias.to(torch.bfloat16)
-
         self.topk_config = TopKConfig(
             top_k=top_k,
             use_grouped_topk=use_grouped_topk,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c1573d8a2..c46655f56 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -65,6 +65,7 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
@@ -269,6 +270,7 @@ class MoEGate(nn.Module):
     def __init__(
         self,
         config,
+        quant_config,
         prefix: str = "",
         is_nextn: bool = False,
     ):
@@ -278,8 +280,15 @@ class MoEGate(nn.Module):
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         if config.topk_method == "noaux_tc":
+            correction_bias_dtype = (
+                torch.bfloat16
+                if quant_config is not None
+                and quant_config.get_name() == "modelopt_fp4"
+                and should_use_flashinfer_trtllm_moe()
+                else torch.float32
+            )
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts), dtype=torch.float32)
+                torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
             )
         else:
             self.e_score_correction_bias = None
@@ -354,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
         )
 
         self.gate = MoEGate(
-            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("gate", prefix),
+            is_nextn=is_nextn,
         )
 
         self.experts = get_moe_impl_class(quant_config)(