Revert "[1/2][resubmit] sgl-kernel: Fuse routed scaling factor into m… (#9035)
This commit is contained in:
@@ -132,7 +132,6 @@ class TopK(CustomOp):
|
||||
scoring_func: str = "softmax",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||
):
|
||||
# NOTE: scoring_func is not used for now, but we keep it for future use
|
||||
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
||||
@@ -148,9 +147,6 @@ class TopK(CustomOp):
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.correction_bias = correction_bias
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.apply_routed_scaling_factor_on_output = (
|
||||
apply_routed_scaling_factor_on_output
|
||||
)
|
||||
|
||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
||||
|
||||
@@ -211,7 +207,6 @@ class TopK(CustomOp):
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
@@ -381,7 +376,6 @@ def grouped_topk_gpu(
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -429,8 +423,6 @@ def grouped_topk_gpu(
|
||||
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||
)
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
if apply_routed_scaling_factor_on_output:
|
||||
topk_weights *= routed_scaling_factor
|
||||
|
||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
@@ -479,7 +471,6 @@ def biased_grouped_topk_impl(
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -531,8 +522,6 @@ def biased_grouped_topk_impl(
|
||||
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||
)
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
if apply_routed_scaling_factor_on_output:
|
||||
topk_weights *= routed_scaling_factor
|
||||
|
||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
@@ -575,10 +564,7 @@ def biased_grouped_topk_gpu(
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||
):
|
||||
# TODO(trevor-m): Remove once sgl-kernel is updated
|
||||
assert not apply_routed_scaling_factor_on_output
|
||||
assert (
|
||||
routed_scaling_factor is not None
|
||||
), "routed_scaling_factor is required for biased_grouped_topk"
|
||||
@@ -597,8 +583,6 @@ def biased_grouped_topk_gpu(
|
||||
topk,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
# TODO(trevor-m): Uncomment once sgl-kernel is updated
|
||||
# apply_routed_scaling_factor_on_output,
|
||||
)
|
||||
# TODO merge into kernel
|
||||
if (expert_location_dispatch_info is not None) or (
|
||||
@@ -609,7 +593,6 @@ def biased_grouped_topk_gpu(
|
||||
)
|
||||
return topk_weights, topk_ids
|
||||
elif _use_aiter:
|
||||
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||
token = gating_output.shape[0]
|
||||
device = gating_output.device
|
||||
assert (
|
||||
@@ -641,7 +624,6 @@ def biased_grouped_topk_gpu(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||
)
|
||||
|
||||
|
||||
@@ -701,7 +683,6 @@ def select_experts(
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||
) -> TopKOutput:
|
||||
router_logits, correction_bias = (
|
||||
expert_location_dispatch.transform_select_experts_inputs(
|
||||
@@ -727,7 +708,6 @@ def select_experts(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = biased_grouped_topk(
|
||||
@@ -742,14 +722,12 @@ def select_experts(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||
)
|
||||
elif torch_native and custom_routing_function is None:
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in fused_topk_native"
|
||||
assert expert_location_dispatch_info is None
|
||||
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||
topk_weights, topk_ids = fused_topk_native(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
@@ -757,7 +735,6 @@ def select_experts(
|
||||
renormalize=renormalize,
|
||||
)
|
||||
elif custom_routing_function is None:
|
||||
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||
# Qwen3MOE uses fused_topk
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
@@ -772,7 +749,6 @@ def select_experts(
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in custom_routing_function"
|
||||
assert expert_location_dispatch_info is None
|
||||
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
|
||||
Reference in New Issue
Block a user