Refactor TopK to ensure readability and extensibility (#9338)

This commit is contained in:
Cheng Wan
2025-09-14 19:16:25 -07:00
committed by GitHub
parent b7d385e812
commit 4844fac91d
14 changed files with 52 additions and 47 deletions

View File

@@ -65,14 +65,10 @@ from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE,
_is_fp4_quantization_enabled,
)
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
@@ -375,10 +371,6 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("experts", prefix),
)
correction_bias = self.gate.e_score_correction_bias
# https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
correction_bias = correction_bias.to(torch.bfloat16)
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
@@ -386,10 +378,13 @@ class DeepseekV2MoE(nn.Module):
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=correction_bias,
correction_bias=self.gate.e_score_correction_bias,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
force_topk=quant_config is None,
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
# and requires the output format to be standard. We use quant_config to determine the output format.
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
)
self.shared_experts_is_int8 = False