Sgl kernel fused_moe_gate support n_shared_experts (#5440)
This commit is contained in:
@@ -146,7 +146,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
|
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
|
||||||
|
"n_share_experts_fusion, float routed_scaling_factor) -> "
|
||||||
"(Tensor[])");
|
"(Tensor[])");
|
||||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl(
|
|||||||
int64_t num_rows,
|
int64_t num_rows,
|
||||||
int64_t topk_group,
|
int64_t topk_group,
|
||||||
int64_t topk,
|
int64_t topk,
|
||||||
|
int64_t n_share_experts_fusion,
|
||||||
|
double routed_scaling_factor,
|
||||||
Params params) {
|
Params params) {
|
||||||
int tidx = threadIdx.x;
|
int tidx = threadIdx.x;
|
||||||
int64_t thread_row =
|
int64_t thread_row =
|
||||||
@@ -65,6 +67,9 @@ __device__ void moe_fused_gate_impl(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculate topk_excluding_share_expert_fusion from topk
|
||||||
|
int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
|
||||||
|
|
||||||
// Cast pointers to type T:
|
// Cast pointers to type T:
|
||||||
auto* input_ptr = reinterpret_cast<T*>(input);
|
auto* input_ptr = reinterpret_cast<T*>(input);
|
||||||
auto* bias_ptr = reinterpret_cast<T*>(bias);
|
auto* bias_ptr = reinterpret_cast<T*>(bias);
|
||||||
@@ -163,7 +168,7 @@ __device__ void moe_fused_gate_impl(
|
|||||||
|
|
||||||
////////////////////// Topk //////////////////////
|
////////////////////// Topk //////////////////////
|
||||||
float output_sum = 0.0f;
|
float output_sum = 0.0f;
|
||||||
for (int k_idx = 0; k_idx < topk; ++k_idx) {
|
for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
|
||||||
// local argmax
|
// local argmax
|
||||||
T max_val = bias_chunk[0];
|
T max_val = bias_chunk[0];
|
||||||
int expert = first_elt_read_by_thread;
|
int expert = first_elt_read_by_thread;
|
||||||
@@ -181,7 +186,7 @@ __device__ void moe_fused_gate_impl(
|
|||||||
max_val = static_cast<T>(-FLT_MAX);
|
max_val = static_cast<T>(-FLT_MAX);
|
||||||
}
|
}
|
||||||
|
|
||||||
// argmax reduce
|
// argmax reduce
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||||
T other_max =
|
T other_max =
|
||||||
@@ -195,7 +200,6 @@ __device__ void moe_fused_gate_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (k_idx < topk) {
|
|
||||||
int thread_to_clear_in_group = expert / params.VPT;
|
int thread_to_clear_in_group = expert / params.VPT;
|
||||||
int64_t idx = topk * thread_row + k_idx;
|
int64_t idx = topk * thread_row + k_idx;
|
||||||
|
|
||||||
@@ -210,21 +214,32 @@ __device__ void moe_fused_gate_impl(
|
|||||||
indices_ptr[idx] = static_cast<int32_t>(expert);
|
indices_ptr[idx] = static_cast<int32_t>(expert);
|
||||||
}
|
}
|
||||||
|
|
||||||
// accumulate sum
|
// accumulate sum for all elements
|
||||||
if (thread_group_idx == 0) {
|
if (thread_group_idx == 0) {
|
||||||
output_sum += output_ptr[idx];
|
output_sum += output_ptr[idx];
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
|
||||||
|
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
|
||||||
|
|
||||||
|
// Use round-robin to select expert
|
||||||
|
int64_t expert_offset = thread_row % n_share_experts_fusion;
|
||||||
|
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
||||||
|
|
||||||
|
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
||||||
|
output_ptr[last_idx] = output_sum / routed_scaling_factor;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
////////////////////// Rescale Output //////////////////////
|
////////////////////// Rescale Output //////////////////////
|
||||||
if (thread_group_idx == 0) {
|
if (thread_group_idx == 0) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int ii = 0; ii < topk; ++ii) {
|
for (int ii = 0; ii < topk; ++ii) {
|
||||||
int64_t const idx = topk * thread_row + ii;
|
int64_t const idx = topk * thread_row + ii;
|
||||||
output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
|
output_ptr[idx] = output_ptr[idx] / output_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -257,9 +272,21 @@ __global__ void moe_fused_gate_kernel(
|
|||||||
int32_t* indices_ptr,
|
int32_t* indices_ptr,
|
||||||
int64_t num_rows,
|
int64_t num_rows,
|
||||||
int64_t topk_group,
|
int64_t topk_group,
|
||||||
int64_t topk) {
|
int64_t topk,
|
||||||
|
int64_t n_share_experts_fusion,
|
||||||
|
double routed_scaling_factor) {
|
||||||
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
|
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
|
||||||
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
|
moe_fused_gate_impl<T>(
|
||||||
|
input,
|
||||||
|
bias,
|
||||||
|
output_ptr,
|
||||||
|
indices_ptr,
|
||||||
|
num_rows,
|
||||||
|
topk_group,
|
||||||
|
topk,
|
||||||
|
n_share_experts_fusion,
|
||||||
|
routed_scaling_factor,
|
||||||
|
params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Macro to compute compile-time constants and launch the kernel.
|
// Macro to compute compile-time constants and launch the kernel.
|
||||||
@@ -277,7 +304,9 @@ __global__ void moe_fused_gate_kernel(
|
|||||||
indices.data_ptr<int32_t>(), \
|
indices.data_ptr<int32_t>(), \
|
||||||
num_rows, \
|
num_rows, \
|
||||||
topk_group, \
|
topk_group, \
|
||||||
topk); \
|
topk, \
|
||||||
|
n_share_experts_fusion, \
|
||||||
|
routed_scaling_factor); \
|
||||||
dispatched = true; \
|
dispatched = true; \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
@@ -303,7 +332,9 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
|||||||
int64_t num_experts,
|
int64_t num_experts,
|
||||||
int64_t num_expert_group,
|
int64_t num_expert_group,
|
||||||
int64_t topk_group,
|
int64_t topk_group,
|
||||||
int64_t topk) {
|
int64_t topk,
|
||||||
|
int64_t n_share_experts_fusion,
|
||||||
|
double routed_scaling_factor) {
|
||||||
KernelParamsDynamic params;
|
KernelParamsDynamic params;
|
||||||
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
|
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
|
||||||
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
|
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
|
||||||
@@ -312,14 +343,30 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
|||||||
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
|
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
|
||||||
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
|
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
|
||||||
|
|
||||||
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
|
moe_fused_gate_impl<T>(
|
||||||
|
input,
|
||||||
|
bias,
|
||||||
|
output_ptr,
|
||||||
|
indices_ptr,
|
||||||
|
num_rows,
|
||||||
|
topk_group,
|
||||||
|
topk,
|
||||||
|
n_share_experts_fusion,
|
||||||
|
routed_scaling_factor,
|
||||||
|
params);
|
||||||
}
|
}
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
// Host Launcher Function
|
// Host Launcher Function
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
std::vector<at::Tensor>
|
std::vector<at::Tensor> moe_fused_gate(
|
||||||
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
|
at::Tensor& input,
|
||||||
|
at::Tensor& bias,
|
||||||
|
int64_t num_expert_group,
|
||||||
|
int64_t topk_group,
|
||||||
|
int64_t topk,
|
||||||
|
int64_t n_share_experts_fusion,
|
||||||
|
double routed_scaling_factor) {
|
||||||
int64_t num_rows = input.size(0);
|
int64_t num_rows = input.size(0);
|
||||||
int32_t num_experts = input.size(1);
|
int32_t num_experts = input.size(1);
|
||||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||||
@@ -416,7 +463,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
|
|||||||
num_experts,
|
num_experts,
|
||||||
num_expert_group,
|
num_expert_group,
|
||||||
topk_group,
|
topk_group,
|
||||||
topk);
|
topk,
|
||||||
|
n_share_experts_fusion,
|
||||||
|
routed_scaling_factor);
|
||||||
} else if (input.scalar_type() == at::kHalf) {
|
} else if (input.scalar_type() == at::kHalf) {
|
||||||
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
|
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||||
input.data_ptr(),
|
input.data_ptr(),
|
||||||
@@ -427,7 +476,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
|
|||||||
num_experts,
|
num_experts,
|
||||||
num_expert_group,
|
num_expert_group,
|
||||||
topk_group,
|
topk_group,
|
||||||
topk);
|
topk,
|
||||||
|
n_share_experts_fusion,
|
||||||
|
routed_scaling_factor);
|
||||||
} else if (input.scalar_type() == at::kFloat) {
|
} else if (input.scalar_type() == at::kFloat) {
|
||||||
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
|
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||||
input.data_ptr(),
|
input.data_ptr(),
|
||||||
@@ -438,7 +489,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
|
|||||||
num_experts,
|
num_experts,
|
||||||
num_expert_group,
|
num_expert_group,
|
||||||
topk_group,
|
topk_group,
|
||||||
topk);
|
topk,
|
||||||
|
n_share_experts_fusion,
|
||||||
|
routed_scaling_factor);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
|
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -200,8 +200,14 @@ void topk_softmax(
|
|||||||
torch::Tensor& token_expert_indices,
|
torch::Tensor& token_expert_indices,
|
||||||
torch::Tensor& gating_output);
|
torch::Tensor& gating_output);
|
||||||
|
|
||||||
std::vector<at::Tensor>
|
std::vector<at::Tensor> moe_fused_gate(
|
||||||
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
|
at::Tensor& input,
|
||||||
|
at::Tensor& bias,
|
||||||
|
int64_t num_expert_group,
|
||||||
|
int64_t topk_group,
|
||||||
|
int64_t topk,
|
||||||
|
int64_t n_share_experts_fusion,
|
||||||
|
double routed_scaling_factor);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/speculative
|
* From csrc/speculative
|
||||||
|
|||||||
@@ -34,13 +34,29 @@ def topk_softmax(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
|
def moe_fused_gate(
|
||||||
|
input_tensor,
|
||||||
|
bias,
|
||||||
|
num_expert_group,
|
||||||
|
topk_group,
|
||||||
|
topk,
|
||||||
|
n_share_experts_fusion=0,
|
||||||
|
routed_scaling_factor=0,
|
||||||
|
):
|
||||||
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
|
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
|
||||||
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
|
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
|
||||||
# as the group weight to select exerpt groups and then select topk experts within the selected groups
|
# as the group weight to select exerpt groups and then select topk experts within the selected groups
|
||||||
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
|
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
|
||||||
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
|
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
|
||||||
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
|
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
|
||||||
|
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
|
||||||
|
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
|
||||||
return torch.ops.sgl_kernel.moe_fused_gate.default(
|
return torch.ops.sgl_kernel.moe_fused_gate.default(
|
||||||
input_tensor, bias, num_expert_group, topk_group, topk
|
input_tensor,
|
||||||
|
bias,
|
||||||
|
num_expert_group,
|
||||||
|
topk_group,
|
||||||
|
topk,
|
||||||
|
n_share_experts_fusion,
|
||||||
|
routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,13 +19,15 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
|
|||||||
(512, 16, 8, 16),
|
(512, 16, 8, 16),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_moe_fused_gate_combined(seq_length, dtype, params):
|
@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16])
|
||||||
|
def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
|
||||||
num_experts, num_expert_group, topk_group, topk = params
|
num_experts, num_expert_group, topk_group, topk = params
|
||||||
|
|
||||||
torch.manual_seed(seq_length)
|
torch.manual_seed(seq_length)
|
||||||
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
|
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
|
||||||
scores = tensor.clone()
|
scores = tensor.clone()
|
||||||
bias = torch.rand(num_experts).to(dtype).cuda()
|
bias = torch.rand(num_experts).to(dtype).cuda()
|
||||||
|
topk = topk + min(1, n_share_experts_fusion)
|
||||||
|
|
||||||
output, indices = moe_fused_gate(
|
output, indices = moe_fused_gate(
|
||||||
tensor,
|
tensor,
|
||||||
@@ -33,6 +35,8 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
|
|||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
topk=topk,
|
topk=topk,
|
||||||
|
n_share_experts_fusion=n_share_experts_fusion,
|
||||||
|
routed_scaling_factor=2.5,
|
||||||
)
|
)
|
||||||
ref_output, ref_indices = biased_grouped_topk(
|
ref_output, ref_indices = biased_grouped_topk(
|
||||||
scores,
|
scores,
|
||||||
@@ -43,8 +47,30 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
|
|||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
compiled=False,
|
compiled=False,
|
||||||
|
n_share_experts_fusion=n_share_experts_fusion,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
|
||||||
|
if n_share_experts_fusion > 0:
|
||||||
|
original_indices = indices.clone()
|
||||||
|
original_ref_indices = ref_indices.clone()
|
||||||
|
|
||||||
|
indices = indices[:, :-1]
|
||||||
|
ref_indices = ref_indices[:, :-1]
|
||||||
|
|
||||||
|
valid_min = num_experts
|
||||||
|
valid_max = num_experts + n_share_experts_fusion
|
||||||
|
shared_indices = original_indices[:, -1]
|
||||||
|
shared_ref_indices = original_ref_indices[:, -1]
|
||||||
|
if shared_indices is not None:
|
||||||
|
assert torch.all(
|
||||||
|
(shared_indices >= valid_min) & (shared_indices < valid_max)
|
||||||
|
), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
|
||||||
|
if shared_ref_indices is not None:
|
||||||
|
assert torch.all(
|
||||||
|
(shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
|
||||||
|
), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"
|
||||||
|
|
||||||
idx_check = torch.allclose(
|
idx_check = torch.allclose(
|
||||||
ref_indices.sort()[0].to(torch.int32),
|
ref_indices.sort()[0].to(torch.int32),
|
||||||
indices.sort()[0].to(torch.int32),
|
indices.sort()[0].to(torch.int32),
|
||||||
@@ -54,17 +80,17 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
|
|||||||
output_check = torch.allclose(
|
output_check = torch.allclose(
|
||||||
ref_output.sort()[0].to(torch.float32),
|
ref_output.sort()[0].to(torch.float32),
|
||||||
output.sort()[0].to(torch.float32),
|
output.sort()[0].to(torch.float32),
|
||||||
rtol=1e-04,
|
rtol=1e-02,
|
||||||
atol=1e-05,
|
atol=1e-03,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert idx_check, (
|
assert idx_check, (
|
||||||
f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
|
f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
|
||||||
f"params {params}"
|
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
|
||||||
)
|
)
|
||||||
assert output_check, (
|
assert output_check, (
|
||||||
f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
|
f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
|
||||||
f"params {params}"
|
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user