Revert "[1/2] sgl-kernel: Fuse routed scaling factor into select_experts" (#8706)

This commit is contained in:
Liangsheng Yin
2025-08-02 20:14:30 +08:00
committed by GitHub
parent ac6962ccd6
commit f9f0138f80
5 changed files with 12 additions and 38 deletions

View File

@@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl(
     int64_t topk,
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl(
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
       output_ptr[idx] = output_ptr[idx] / output_sum;
-      if (apply_routed_scaling_factor_on_output) {
-        output_ptr[idx] *= routed_scaling_factor;
-      }
     }
   }
 }
@@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
       input,
@@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
-      apply_routed_scaling_factor_on_output,
       params);
 }
@@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel(
       topk_group, \
       topk, \
       num_fused_shared_experts, \
-      routed_scaling_factor, \
-      apply_routed_scaling_factor_on_output); \
+      routed_scaling_factor); \
     dispatched = true; \
   } while (0)
@@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic(
      topk,
      num_fused_shared_experts,
      routed_scaling_factor,
-     apply_routed_scaling_factor_on_output,
      params);
 }
@@ -383,8 +374,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -483,8 +473,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -497,8 +486,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -511,8 +499,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
   }