Sgl kernel fused_moe_gate support n_shared_experts (#5440)

This commit is contained in:
Xiaoyu Zhang
2025-04-18 14:05:15 +08:00
committed by GitHub
parent 53dcf38876
commit 8e09b37077
5 changed files with 140 additions and 38 deletions

View File

@@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl(
int64_t num_rows,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor,
Params params) {
int tidx = threadIdx.x;
int64_t thread_row =
@@ -65,6 +67,9 @@ __device__ void moe_fused_gate_impl(
return;
}
// Calculate topk_excluding_share_expert_fusion from topk
int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
// Cast pointers to type T:
auto* input_ptr = reinterpret_cast<T*>(input);
auto* bias_ptr = reinterpret_cast<T*>(bias);
@@ -163,7 +168,7 @@ __device__ void moe_fused_gate_impl(
////////////////////// Topk //////////////////////
float output_sum = 0.0f;
for (int k_idx = 0; k_idx < topk; ++k_idx) {
for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
// local argmax
T max_val = bias_chunk[0];
int expert = first_elt_read_by_thread;
@@ -181,7 +186,7 @@ __device__ void moe_fused_gate_impl(
max_val = static_cast<T>(-FLT_MAX);
}
// argmax reduce
// argmax reduce
#pragma unroll
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
T other_max =
@@ -195,36 +200,46 @@ __device__ void moe_fused_gate_impl(
}
}
if (k_idx < topk) {
int thread_to_clear_in_group = expert / params.VPT;
int64_t idx = topk * thread_row + k_idx;
int thread_to_clear_in_group = expert / params.VPT;
int64_t idx = topk * thread_row + k_idx;
if (thread_group_idx == thread_to_clear_in_group) {
int expert_to_clear_in_thread = expert % params.VPT;
if (thread_group_idx == thread_to_clear_in_group) {
int expert_to_clear_in_thread = expert % params.VPT;
// clear the max value in the thread
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
// clear the max value in the thread
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
// store output
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
indices_ptr[idx] = static_cast<int32_t>(expert);
}
// store output
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
indices_ptr[idx] = static_cast<int32_t>(expert);
}
// accumulate sum
if (thread_group_idx == 0) {
output_sum += output_ptr[idx];
}
// accumulate sum for all elements
if (thread_group_idx == 0) {
output_sum += output_ptr[idx];
}
__syncthreads();
}
if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
// Use round-robin to select expert
int64_t expert_offset = thread_row % n_share_experts_fusion;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
}
__syncthreads();
////////////////////// Rescale Output //////////////////////
if (thread_group_idx == 0) {
#pragma unroll
for (int ii = 0; ii < topk; ++ii) {
int64_t const idx = topk * thread_row + ii;
output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
output_ptr[idx] = output_ptr[idx] / output_sum;
}
}
}
@@ -257,9 +272,21 @@ __global__ void moe_fused_gate_kernel(
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk) {
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
params);
}
// Macro to compute compile-time constants and launch the kernel.
@@ -277,7 +304,9 @@ __global__ void moe_fused_gate_kernel(
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk); \
topk, \
n_share_experts_fusion, \
routed_scaling_factor); \
dispatched = true; \
} while (0)
@@ -303,7 +332,9 @@ __global__ void moe_fused_gate_kernel_dynamic(
int64_t num_experts,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk) {
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
KernelParamsDynamic params;
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -312,14 +343,30 @@ __global__ void moe_fused_gate_kernel_dynamic(
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
params);
}
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
std::vector<at::Tensor>
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
int64_t num_rows = input.size(0);
int32_t num_experts = input.size(1);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -416,7 +463,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
num_experts,
num_expert_group,
topk_group,
topk);
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else if (input.scalar_type() == at::kHalf) {
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
@@ -427,7 +476,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
num_experts,
num_expert_group,
topk_group,
topk);
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else if (input.scalar_type() == at::kFloat) {
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
@@ -438,7 +489,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
num_experts,
num_expert_group,
topk_group,
topk);
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else {
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
}