[1/2] sgl-kernel: Fuse routed scaling factor into select_experts (#8364)
This commit is contained in:
@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def(
|
||||
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
|
||||
"num_fused_shared_experts, float routed_scaling_factor) -> "
|
||||
"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
|
||||
"(Tensor[])");
|
||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||
m.def(
|
||||
|
||||
@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
|
||||
int64_t topk,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor,
|
||||
bool apply_routed_scaling_factor_on_output,
|
||||
Params params) {
|
||||
int tidx = threadIdx.x;
|
||||
int64_t thread_row =
|
||||
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
|
||||
for (int ii = 0; ii < topk; ++ii) {
|
||||
int64_t const idx = topk * thread_row + ii;
|
||||
output_ptr[idx] = output_ptr[idx] / output_sum;
|
||||
if (apply_routed_scaling_factor_on_output) {
|
||||
output_ptr[idx] *= routed_scaling_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
double routed_scaling_factor,
|
||||
bool apply_routed_scaling_factor_on_output) {
|
||||
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
|
||||
moe_fused_gate_impl<T>(
|
||||
input,
|
||||
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
|
||||
topk,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
apply_routed_scaling_factor_on_output,
|
||||
params);
|
||||
}
|
||||
|
||||
@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
|
||||
topk_group, \
|
||||
topk, \
|
||||
num_fused_shared_experts, \
|
||||
routed_scaling_factor); \
|
||||
routed_scaling_factor, \
|
||||
apply_routed_scaling_factor_on_output); \
|
||||
dispatched = true; \
|
||||
} while (0)
|
||||
|
||||
@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
double routed_scaling_factor,
|
||||
bool apply_routed_scaling_factor_on_output) {
|
||||
KernelParamsDynamic params;
|
||||
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
|
||||
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
|
||||
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
||||
topk,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
apply_routed_scaling_factor_on_output,
|
||||
params);
|
||||
}
|
||||
|
||||
@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
double routed_scaling_factor,
|
||||
bool apply_routed_scaling_factor_on_output) {
|
||||
int64_t num_rows = input.size(0);
|
||||
int32_t num_experts = input.size(1);
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
topk_group,
|
||||
topk,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
routed_scaling_factor,
|
||||
apply_routed_scaling_factor_on_output);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input.data_ptr(),
|
||||
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
topk_group,
|
||||
topk,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
routed_scaling_factor,
|
||||
apply_routed_scaling_factor_on_output);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input.data_ptr(),
|
||||
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
topk_group,
|
||||
topk,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
routed_scaling_factor,
|
||||
apply_routed_scaling_factor_on_output);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user