[2/2] Fuse routed scaling factor into select_experts (#8690)

This commit is contained in:
Trevor Morris
2025-08-20 15:10:16 -07:00
committed by GitHub
parent f96413c444
commit a91e90d9a3
6 changed files with 55 additions and 25 deletions

View File

@@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
def test_moe_fused_gate_combined(
seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
):
num_experts, num_expert_group, topk_group, topk = params
dtype = torch.float32
@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
topk=topk,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=2.5,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
ref_output, ref_indices = biased_grouped_topk(
scores,
@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
topk_group=topk_group,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=2.5,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
# When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension