Revert "[1/2][resubmit] sgl-kernel: Fuse routed scaling factor into m… (#9035)
@@ -132,7 +132,6 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
-        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -148,9 +147,6 @@ class TopK(CustomOp):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
-        self.apply_routed_scaling_factor_on_output = (
-            apply_routed_scaling_factor_on_output
-        )

         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]

@@ -211,7 +207,6 @@ class TopK(CustomOp):
             routed_scaling_factor=self.routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
-            apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
         )

     def forward_cpu(
@@ -381,7 +376,6 @@ def grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -429,8 +423,6 @@ def grouped_topk_gpu(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
-    if apply_routed_scaling_factor_on_output:
-        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -479,7 +471,6 @@ def biased_grouped_topk_impl(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -531,8 +522,6 @@ def biased_grouped_topk_impl(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
-    if apply_routed_scaling_factor_on_output:
-        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -575,10 +564,7 @@ def biased_grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
-    # TODO(trevor-m): Remove once sgl-kernel is updated
-    assert not apply_routed_scaling_factor_on_output
     assert (
         routed_scaling_factor is not None
     ), "routed_scaling_factor is required for biased_grouped_topk"
@@ -597,8 +583,6 @@ def biased_grouped_topk_gpu(
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
-            # TODO(trevor-m): Uncomment once sgl-kernel is updated
-            # apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -609,7 +593,6 @@ def biased_grouped_topk_gpu(
             )
         return topk_weights, topk_ids
     elif _use_aiter:
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         token = gating_output.shape[0]
         device = gating_output.device
         assert (
@@ -641,7 +624,6 @@ def biased_grouped_topk_gpu(
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
-            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )


@@ -701,7 +683,6 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ) -> TopKOutput:
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -727,7 +708,6 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
-                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -742,14 +722,12 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
-                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
     elif torch_native and custom_routing_function is None:
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -757,7 +735,6 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
@@ -772,7 +749,6 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
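In the Python branches above, apply_routed_scaling_factor_on_output controls whether the routed scaling factor is multiplied into the top-k weights right after renormalization; with this revert, that multiplication again happens outside of top-k selection. A minimal standalone sketch of the normalize-then-scale step (the helper name and signature below are illustrative, not part of sglang):

import torch

def normalize_and_scale(
    topk_weights: torch.Tensor,  # [num_tokens, topk] routing weights
    routed_scaling_factor: float,
    renormalize: bool = True,
    apply_on_output: bool = False,
) -> torch.Tensor:
    # Renormalize so each token's selected expert weights sum to 1.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # If requested, fold the routed scaling factor into the weights here;
    # otherwise the MoE layer scales the combined expert output later.
    if apply_on_output:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights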
@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
+      "num_fused_shared_experts, float routed_scaling_factor) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
   m.def(

@@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl(
     int64_t topk,
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl(
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
       output_ptr[idx] = output_ptr[idx] / output_sum;
-      if (apply_routed_scaling_factor_on_output) {
-        output_ptr[idx] *= routed_scaling_factor;
-      }
     }
   }
 }
@@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
       input,
@@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
-      apply_routed_scaling_factor_on_output,
       params);
 }

@@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel(
         topk_group,                               \
         topk,                                     \
         num_fused_shared_experts,                 \
-        routed_scaling_factor,                    \
-        apply_routed_scaling_factor_on_output);   \
+        routed_scaling_factor);                   \
     dispatched = true;                            \
   } while (0)

@@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;             // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
-      apply_routed_scaling_factor_on_output,
       params);
 }

@@ -383,8 +374,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -483,8 +473,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -497,8 +486,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -511,8 +499,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
   }

@@ -243,8 +243,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output);
+    double routed_scaling_factor);

 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,

@@ -44,7 +44,6 @@ def moe_fused_gate(
     topk,
     num_fused_shared_experts=0,
     routed_scaling_factor=0,
-    apply_routed_scaling_factor_on_output=False,
 ):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
@@ -52,13 +51,8 @@ def moe_fused_gate(
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
     # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
-    # num_fused_shared_experts: if > 0, the last several experts will be
-    # replaced with shared experts. the shared experts will be divided by the
-    # routed_scaling_factor - this is intended to cancel out later when routed+shared
-    # output is scaled so that shared experts are not scaled.
-    # routed_scaling_factor: if > 0, the experts will be scaled by this factor
-    # apply_routed_scaling_factor_on_output: if true, output will be
-    # scaled by the routed_scaling_factor
+    # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
+    # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
         input_tensor,
         bias,
@@ -67,7 +61,6 @@ def moe_fused_gate(
         topk,
         num_fused_shared_experts,
         routed_scaling_factor,
-        apply_routed_scaling_factor_on_output,
     )


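The comments above describe the hierarchical 2-layer selection that moe_fused_gate performs: experts are split into num_expert_group groups, each group is scored by the sum of its top-2 expert weights, topk_group groups are kept, and the final topk experts are picked from the surviving groups. A rough, unoptimized Python sketch of that selection (illustrative only; it omits the bias term and the shared-expert handling and is not the kernel's actual implementation):

import torch

def hierarchical_topk(scores: torch.Tensor, num_expert_group: int, topk_group: int, topk: int):
    # scores: [num_tokens, num_experts]; num_experts must be divisible by num_expert_group.
    num_tokens, num_experts = scores.shape
    grouped = scores.view(num_tokens, num_expert_group, -1)
    # Score each group by the sum of its top-2 expert weights.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    top_groups = group_scores.topk(topk_group, dim=-1).indices
    # Mask out experts whose group was not selected.
    keep = torch.zeros_like(group_scores, dtype=torch.bool)
    keep.scatter_(1, top_groups, True)
    masked = grouped.masked_fill(~keep.unsqueeze(-1), float("-inf"))
    # Final top-k over the surviving experts, with flattened expert ids.
    weights, ids = masked.view(num_tokens, num_experts).topk(topk, dim=-1)
    return weights, ids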
@@ -19,10 +19,7 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
     ],
 )
 @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
-def test_moe_fused_gate_combined(
-    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
-):
+def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
     num_experts, num_expert_group, topk_group, topk = params
     dtype = torch.float32

@@ -40,7 +37,6 @@ def test_moe_fused_gate_combined(
         topk=topk,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
-        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,
@@ -52,7 +48,6 @@ def test_moe_fused_gate_combined(
         topk_group=topk_group,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
-        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )

     # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension