[1/2][resubmit] sgl-kernel: Fuse routed scaling factor into moe_fused_gate (select_experts) (#8770)

Trevor Morris
2025-08-08 17:55:06 -07:00
committed by GitHub
parent f352b793be
commit 591c232f7c
6 changed files with 62 additions and 12 deletions
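
What the fusion buys, sketched below in plain PyTorch (a conceptual illustration of the math only, not the kernel or sglang code): instead of combining expert outputs with the top-k weights and then multiplying the combined [num_tokens, hidden] activations by routed_scaling_factor, the factor is folded into the much smaller [num_tokens, topk] weights at expert-selection time, so the downstream elementwise multiply disappears.

    import torch

    def combine(expert_outputs, topk_weights):
        # expert_outputs: [num_tokens, topk, hidden]; topk_weights: [num_tokens, topk]
        return (expert_outputs * topk_weights.unsqueeze(-1)).sum(dim=1)

    expert_outputs = torch.randn(4, 2, 8)
    topk_weights = torch.rand(4, 2)
    routed_scaling_factor = 2.5  # illustrative value, not taken from the PR

    # Unfused: scale the combined output (an extra pass over [num_tokens, hidden]).
    out_unfused = combine(expert_outputs, topk_weights) * routed_scaling_factor

    # Fused: scale the [num_tokens, topk] weights during selection instead.
    out_fused = combine(expert_outputs, topk_weights * routed_scaling_factor)

    torch.testing.assert_close(out_unfused, out_fused)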


@@ -132,6 +132,7 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -147,6 +148,9 @@ class TopK(CustomOp):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
+        self.apply_routed_scaling_factor_on_output = (
+            apply_routed_scaling_factor_on_output
+        )
         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
@@ -207,6 +211,7 @@ class TopK(CustomOp):
             routed_scaling_factor=self.routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
         )
 
     def forward_cpu(
@@ -375,6 +380,7 @@ def grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -422,6 +428,8 @@ def grouped_topk_gpu(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
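
The ordering established here matters: weights are renormalized first and scaled afterwards, so with the flag set the returned weights sum to routed_scaling_factor rather than 1. A minimal standalone illustration (plain torch, illustrative numbers):

    import torch

    routed_scaling_factor = 2.5  # illustrative value
    topk_weights = torch.tensor([[0.6, 0.3]])  # one token's top-2 gate outputs

    # Renormalize over the selected experts, then fold in the scaling factor,
    # mirroring the order in the hunk above.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights * routed_scaling_factor

    print(topk_weights)        # tensor([[1.6667, 0.8333]])
    print(topk_weights.sum())  # tensor(2.5000), i.e. the scaling factor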
@@ -468,6 +476,7 @@ def biased_grouped_topk_impl(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -519,6 +528,8 @@ def biased_grouped_topk_impl(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -561,7 +572,10 @@ def biased_grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
+    # TODO(trevor-m): Remove once sgl-kernel is updated
+    assert not apply_routed_scaling_factor_on_output
     assert (
         routed_scaling_factor is not None
     ), "routed_scaling_factor is required for biased_grouped_topk"
@@ -580,6 +594,8 @@ def biased_grouped_topk_gpu(
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            # TODO(trevor-m): Uncomment once sgl-kernel is updated
+            # apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
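
Read together, the two hunks above stage the rollout: the Python-side plumbing lands in this commit, while the extra moe_fused_gate argument stays commented out until sgl-kernel exposes it, and the assert at the top of biased_grouped_topk_gpu rejects apply_routed_scaling_factor_on_output=True on this path in the meantime (presumably what the "[1/2]" in the title refers to).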
@@ -590,6 +606,7 @@ def biased_grouped_topk_gpu(
             )
         return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         token = gating_output.shape[0]
         device = gating_output.device
         assert (
@@ -621,6 +638,7 @@ def biased_grouped_topk_gpu(
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )
@@ -680,6 +698,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ) -> TopKOutput:
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -705,6 +724,7 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -719,12 +739,14 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
     elif torch_native and custom_routing_function is None:
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -732,6 +754,7 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
@@ -746,6 +769,7 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
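
Across the paths above, the contract for backends that do implement the flag is that only the weights change: expert selection is identical, and the returned weights are exactly the unscaled weights times routed_scaling_factor. A self-contained property check along those lines (plain torch with a hypothetical naive_select stand-in, not sglang's API or test suite):

    import torch

    def naive_select(scores, topk, renormalize, routed_scaling_factor, apply_on_output):
        # Hypothetical stand-in mirroring the Python top-k paths above.
        topk_weights, topk_ids = torch.topk(scores.softmax(dim=-1), topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        if apply_on_output:
            topk_weights = topk_weights * routed_scaling_factor
        return topk_weights, topk_ids

    scores = torch.randn(16, 64)  # [num_tokens, num_experts]
    w_fused, ids_fused = naive_select(scores, 8, True, 2.5, True)
    w_plain, ids_plain = naive_select(scores, 8, True, 2.5, False)

    assert torch.equal(ids_fused, ids_plain)            # selection unchanged
    torch.testing.assert_close(w_fused, w_plain * 2.5)  # only the weights are scaled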