Add deepseek style fused moe group gate selection kernel (#4530)
This commit is contained in:
@@ -151,6 +151,7 @@ set(SOURCES
|
||||
"csrc/gemm/per_token_group_quant_8bit.cu"
|
||||
"csrc/gemm/per_token_quant_fp8.cu"
|
||||
"csrc/moe/moe_align_kernel.cu"
|
||||
"csrc/moe/moe_fused_gate.cu"
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||
"csrc/speculative/eagle_utils.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
|
||||
74
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
74
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
|
||||
|
||||
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
|
||||
return biased_grouped_topk(
|
||||
scores,
|
||||
scores,
|
||||
bias,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
|
||||
def biased_grouped_topk_org_kernel(scores, bias, num_expert_group, topk_group, topk):
|
||||
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
|
||||
|
||||
|
||||
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
|
||||
configs = [(sq,) for sq in seq_length_range]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["seq_length"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["original", "kernel"],
|
||||
line_names=["Original", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="moe-fused-gate-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(seq_length, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
|
||||
|
||||
scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
|
||||
bias = torch.rand(num_experts, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "original":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: biased_grouped_topk_org(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "kernel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: biased_grouped_topk_org_kernel(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
447
sgl-kernel/csrc/moe/moe_fused_gate.cu
Normal file
447
sgl-kernel/csrc/moe/moe_fused_gate.cu
Normal file
@@ -0,0 +1,447 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <stdio.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cfloat>
|
||||
#include <type_traits>
|
||||
template <typename T, int N>
|
||||
using AlignedArray = cutlass::AlignedArray<T, N>;
|
||||
using bfloat16_t = cutlass::bfloat16_t;
|
||||
using float16_t = cutlass::half_t;
|
||||
using float32_t = float;
|
||||
|
||||
// QQ NOTE: to handle the case for at::Half, error: more than one operator ">" matches these operands: built-in operator
|
||||
// "arithmetic > arithmetic" function "operator>(const __half &, const __half &)"
|
||||
template <typename T>
|
||||
__device__ inline bool cmp_gt(const T& a, const T& b) {
|
||||
if constexpr (std::is_same<T, at::Half>::value) {
|
||||
// at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
|
||||
return static_cast<float>(a) > static_cast<float>(b);
|
||||
} else {
|
||||
// For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
|
||||
return a > b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline bool cmp_eq(const T& a, const T& b) {
|
||||
if constexpr (std::is_same<T, at::Half>::value) {
|
||||
return static_cast<float>(a) == static_cast<float>(b);
|
||||
} else {
|
||||
return a == b;
|
||||
}
|
||||
}
|
||||
|
||||
// Fixed constants common to both dynamic and static template versions:
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
static constexpr int WARPS_PER_CTA = 6;
|
||||
static constexpr int MAX_VPT = 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group
|
||||
|
||||
// Create an alias for Array using AlignedArray
|
||||
template <typename T, int N>
|
||||
using Array = AlignedArray<T, N>;
|
||||
// QQ: NOTE expression must have a constant value, this has to be > params.VPT
|
||||
template <typename T>
|
||||
using AccessType = AlignedArray<T, MAX_VPT>;
|
||||
|
||||
template <typename T, typename Params>
|
||||
__device__ void moe_fused_gate_impl(
|
||||
void* input,
|
||||
void* bias,
|
||||
float* output_ptr,
|
||||
int32_t* indices_ptr,
|
||||
int64_t num_rows,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
Params params) {
|
||||
int tidx = threadIdx.x;
|
||||
int64_t thread_row =
|
||||
blockIdx.x * params.ROWS_PER_CTA + threadIdx.y * params.ROWS_PER_WARP + tidx / params.THREADS_PER_ROW;
|
||||
if (thread_row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Cast pointers to type T:
|
||||
auto* input_ptr = reinterpret_cast<T*>(input);
|
||||
auto* bias_ptr = reinterpret_cast<T*>(bias);
|
||||
auto* thread_row_ptr = input_ptr + thread_row * params.NUM_EXPERTS;
|
||||
|
||||
int thread_group_idx = tidx % params.THREADS_PER_ROW;
|
||||
int first_elt_read_by_thread = thread_group_idx * params.VPT;
|
||||
|
||||
// Create local arrays for the row chunk and bias chunk and then reinterpret the address of row_chunk as a pointer to
|
||||
// AccessType.
|
||||
T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
Array<T, MAX_VPT> row_chunk;
|
||||
AccessType<T> const* vec_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(thread_read_ptr);
|
||||
|
||||
T* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread;
|
||||
Array<T, MAX_VPT> bias_chunk;
|
||||
AccessType<T> const* vec_bias_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(bias_thread_read_ptr);
|
||||
|
||||
// QQ NOTE: doing the follow will be slower than loop assign and more importantly
|
||||
// have misaligned address issue when params.VPT < 8 and mismatch with MAX_VPT
|
||||
// AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
|
||||
// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < params.VPT; ++ii) {
|
||||
row_chunk[ii] = vec_thread_read_ptr[0][ii];
|
||||
bias_chunk[ii] = vec_bias_thread_read_ptr[0][ii];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
////////////////////// Sigmoid //////////////////////
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < params.VPT; ++ii) {
|
||||
row_chunk[ii] = static_cast<T>(1.0f / (1.0f + expf(-float(row_chunk[ii]))));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
////////////////////// Add Bias //////////////////////
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < params.VPT; ++ii) {
|
||||
bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii];
|
||||
}
|
||||
|
||||
////////////////////// Exclude Groups //////////////////////
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group;
|
||||
++k_idx) { // QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
|
||||
int expert = first_elt_read_by_thread;
|
||||
// local argmax
|
||||
T max_val = static_cast<T>(-FLT_MAX);
|
||||
T max_val_second = static_cast<T>(-FLT_MAX);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < params.VPT; ++ii) {
|
||||
T val = bias_chunk[ii];
|
||||
|
||||
if (cmp_gt(val, max_val)) {
|
||||
max_val_second = max_val;
|
||||
max_val = val;
|
||||
} else if (cmp_gt(val, max_val_second)) {
|
||||
max_val_second = val;
|
||||
}
|
||||
}
|
||||
|
||||
// QQ NOTE: currently fixed to pick top2 sigmoid weight value in each expert group and sum them as the group weight
|
||||
// to select expert groups
|
||||
T max_sum = max_val + max_val_second;
|
||||
|
||||
// argmin reduce
|
||||
#pragma unroll
|
||||
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
T other_max_sum =
|
||||
static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_sum), mask, params.THREADS_PER_ROW));
|
||||
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
|
||||
|
||||
// higher indices win
|
||||
if (cmp_gt(max_sum, other_max_sum) || (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) {
|
||||
max_sum = other_max_sum;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// clear the max value in the thread
|
||||
if (k_idx < params.THREADS_PER_ROW - topk_group) {
|
||||
int const thread_to_clear_in_group = expert / params.VPT;
|
||||
|
||||
if (thread_group_idx == thread_to_clear_in_group) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < params.VPT; ++ii) {
|
||||
bias_chunk[ii] = static_cast<T>(FLT_MAX);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
////////////////////// Topk //////////////////////
|
||||
float output_sum = 0.0f;
|
||||
for (int k_idx = 0; k_idx < topk; ++k_idx) {
|
||||
// local argmax
|
||||
T max_val = bias_chunk[0];
|
||||
int expert = first_elt_read_by_thread;
|
||||
|
||||
if (!cmp_eq(max_val, static_cast<T>(FLT_MAX))) {
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < params.VPT; ++ii) {
|
||||
T val = bias_chunk[ii];
|
||||
if (cmp_gt(val, max_val)) {
|
||||
max_val = val;
|
||||
expert = first_elt_read_by_thread + ii;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
max_val = static_cast<T>(-FLT_MAX);
|
||||
}
|
||||
|
||||
// argmax reduce
|
||||
#pragma unroll
|
||||
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
T other_max =
|
||||
static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_val), mask, params.THREADS_PER_ROW));
|
||||
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
|
||||
|
||||
// lower indices to win
|
||||
if (cmp_gt(other_max, max_val) || (cmp_eq(other_max, max_val) && other_expert < expert)) {
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
if (k_idx < topk) {
|
||||
int thread_to_clear_in_group = expert / params.VPT;
|
||||
int64_t idx = topk * thread_row + k_idx;
|
||||
|
||||
if (thread_group_idx == thread_to_clear_in_group) {
|
||||
int expert_to_clear_in_thread = expert % params.VPT;
|
||||
|
||||
// clear the max value in the thread
|
||||
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
|
||||
|
||||
// store output
|
||||
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
|
||||
indices_ptr[idx] = static_cast<int32_t>(expert);
|
||||
}
|
||||
|
||||
// accumulate sum
|
||||
if (thread_group_idx == 0) {
|
||||
output_sum += output_ptr[idx];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
////////////////////// Rescale Output //////////////////////
|
||||
if (thread_group_idx == 0) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < topk; ++ii) {
|
||||
int64_t const idx = topk * thread_row + ii;
|
||||
output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Templated Kernel Version (using compile-time constants)
|
||||
//------------------------------------------------------------------------------
|
||||
template <int VPT_, int NUM_EXPERTS_, int THREADS_PER_ROW_, int ROWS_PER_WARP_, int ROWS_PER_CTA_, int WARPS_PER_CTA_>
|
||||
struct KernelParams {
|
||||
static constexpr int VPT = VPT_;
|
||||
static constexpr int NUM_EXPERTS = NUM_EXPERTS_;
|
||||
static constexpr int THREADS_PER_ROW = THREADS_PER_ROW_;
|
||||
static constexpr int ROWS_PER_WARP = ROWS_PER_WARP_;
|
||||
static constexpr int ROWS_PER_CTA = ROWS_PER_CTA_;
|
||||
static constexpr int WARPS_PER_CTA = WARPS_PER_CTA_;
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int VPT,
|
||||
int NUM_EXPERTS,
|
||||
int THREADS_PER_ROW,
|
||||
int ROWS_PER_WARP,
|
||||
int ROWS_PER_CTA,
|
||||
int WARPS_PER_CTA>
|
||||
__global__ void moe_fused_gate_kernel(
|
||||
void* input,
|
||||
void* bias,
|
||||
float* output_ptr,
|
||||
int32_t* indices_ptr,
|
||||
int64_t num_rows,
|
||||
int64_t topk_group,
|
||||
int64_t topk) {
|
||||
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
|
||||
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
|
||||
}
|
||||
|
||||
// Macro to compute compile-time constants and launch the kernel.
|
||||
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
|
||||
do { \
|
||||
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
|
||||
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */ \
|
||||
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
|
||||
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
|
||||
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
|
||||
<<<num_blocks, block_dim, 0, stream>>>( \
|
||||
input.data_ptr(), \
|
||||
bias.data_ptr(), \
|
||||
output.data_ptr<float>(), \
|
||||
indices.data_ptr<int32_t>(), \
|
||||
num_rows, \
|
||||
topk_group, \
|
||||
topk); \
|
||||
dispatched = true; \
|
||||
} while (0)
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Dynamic Kernel Version (parameters computed at runtime)
|
||||
//------------------------------------------------------------------------------
|
||||
struct KernelParamsDynamic {
|
||||
int VPT;
|
||||
int NUM_EXPERTS;
|
||||
int THREADS_PER_ROW;
|
||||
int ROWS_PER_WARP;
|
||||
int ROWS_PER_CTA;
|
||||
int WARPS_PER_CTA;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__global__ void moe_fused_gate_kernel_dynamic(
|
||||
void* input,
|
||||
void* bias,
|
||||
float* output_ptr,
|
||||
int32_t* indices_ptr,
|
||||
int64_t num_rows,
|
||||
int64_t num_experts,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t topk) {
|
||||
KernelParamsDynamic params;
|
||||
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
|
||||
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
|
||||
params.THREADS_PER_ROW = num_expert_group; // fixed as num_expert_group, e.g., for deepseek v3, this is 8
|
||||
params.WARPS_PER_CTA = WARPS_PER_CTA; // fixed as 6
|
||||
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
|
||||
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
|
||||
|
||||
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Host Launcher Function
|
||||
//------------------------------------------------------------------------------
|
||||
std::vector<at::Tensor>
|
||||
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
|
||||
int64_t num_rows = input.size(0);
|
||||
int32_t num_experts = input.size(1);
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto output = torch::empty({num_rows, topk}, options);
|
||||
auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
|
||||
|
||||
// Compute grid dimensions based on runtime value for num_expert_group.
|
||||
int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
|
||||
int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
|
||||
int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);
|
||||
|
||||
// Check 1: Ensure that num_experts is a power of 2.
|
||||
TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);
|
||||
|
||||
// Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
|
||||
TORCH_CHECK(
|
||||
num_experts % num_expert_group == 0,
|
||||
"num_experts must be divisible by num_expert_group, but got ",
|
||||
num_experts,
|
||||
" / ",
|
||||
num_expert_group);
|
||||
|
||||
int computed_vpt = num_experts / num_expert_group;
|
||||
// Check 3: Ensure that num_experts/num_expert_group does not exceed MAX_VPT=32. Maximum VPT indicate max value per
|
||||
// threads we can process.
|
||||
TORCH_CHECK(
|
||||
computed_vpt <= MAX_VPT,
|
||||
"Per group experts: num_experts / num_expert_group = (",
|
||||
computed_vpt,
|
||||
") exceeds the maximum supported (",
|
||||
MAX_VPT,
|
||||
")");
|
||||
|
||||
// Dispatch to templated kernel for known compile-time configurations.
|
||||
// We currently only support for:
|
||||
// Case 1: 256 experts, with 8 or 16 groups.
|
||||
// Case 2: 128 experts, with 4 or 8 groups.
|
||||
// Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32
|
||||
bool dispatched = false;
|
||||
switch (num_experts) {
|
||||
case 256:
|
||||
if (num_expert_group == 8)
|
||||
// This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
|
||||
if (input.scalar_type() == at::kBFloat16) {
|
||||
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8);
|
||||
} else if (num_expert_group == 16)
|
||||
// Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
|
||||
if (input.scalar_type() == at::kBFloat16) {
|
||||
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
|
||||
}
|
||||
break;
|
||||
case 128:
|
||||
if (num_expert_group == 4)
|
||||
// VPT = 128/4 = 32, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
|
||||
if (input.scalar_type() == at::kBFloat16) {
|
||||
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 4);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4);
|
||||
} else if (num_expert_group == 8)
|
||||
// VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
|
||||
if (input.scalar_type() == at::kBFloat16) {
|
||||
LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
if (!dispatched) {
|
||||
// Fallback to the dynamic kernel if none of the supported combinations match.
|
||||
// currently only support num_experts / num_expert_group <= 32 for dynamic kernels
|
||||
if (input.scalar_type() == at::kBFloat16) {
|
||||
moe_fused_gate_kernel_dynamic<bfloat16_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input.data_ptr(),
|
||||
bias.data_ptr(),
|
||||
output.data_ptr<float>(),
|
||||
indices.data_ptr<int32_t>(),
|
||||
num_rows,
|
||||
num_experts,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input.data_ptr(),
|
||||
bias.data_ptr(),
|
||||
output.data_ptr<float>(),
|
||||
indices.data_ptr<int32_t>(),
|
||||
num_rows,
|
||||
num_experts,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input.data_ptr(),
|
||||
bias.data_ptr(),
|
||||
output.data_ptr<float>(),
|
||||
indices.data_ptr<int32_t>(),
|
||||
num_rows,
|
||||
num_experts,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
|
||||
}
|
||||
}
|
||||
return {output, indices};
|
||||
}
|
||||
@@ -138,6 +138,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
"token_expert_indices, Tensor gating_output) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
m.def(
|
||||
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
|
||||
"(Tensor[])");
|
||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
@@ -199,6 +199,9 @@ void topk_softmax(
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
||||
|
||||
std::vector<at::Tensor>
|
||||
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
@@ -36,7 +36,7 @@ from sgl_kernel.gemm import (
|
||||
sgl_per_token_group_quant_int8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
from sgl_kernel.moe import moe_align_block_size, topk_softmax
|
||||
from sgl_kernel.moe import moe_align_block_size, moe_fused_gate, topk_softmax
|
||||
from sgl_kernel.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_prob,
|
||||
|
||||
@@ -32,3 +32,15 @@ def topk_softmax(
|
||||
torch.ops.sgl_kernel.topk_softmax.default(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output
|
||||
)
|
||||
|
||||
|
||||
def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
|
||||
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
|
||||
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
|
||||
# as the group weight to select exerpt groups and then select topk experts within the selected groups
|
||||
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
|
||||
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
|
||||
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
|
||||
return torch.ops.sgl_kernel.moe_fused_gate(
|
||||
input_tensor, bias, num_expert_group, topk_group, topk
|
||||
)
|
||||
|
||||
@@ -161,6 +161,7 @@ sources = [
|
||||
"csrc/gemm/per_token_quant_fp8.cu",
|
||||
"csrc/gemm/per_tensor_quant_fp8.cu",
|
||||
"csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/moe/moe_fused_gate.cu",
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu",
|
||||
"csrc/speculative/eagle_utils.cu",
|
||||
"csrc/speculative/speculative_sampling.cu",
|
||||
|
||||
72
sgl-kernel/tests/test_moe_fused_gate.py
Normal file
72
sgl-kernel/tests/test_moe_fused_gate.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_length",
|
||||
list(range(1, 10))
|
||||
+ [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
[
|
||||
(128, 4, 2, 4),
|
||||
(256, 8, 4, 8), # deepseek v3
|
||||
(512, 16, 8, 16),
|
||||
],
|
||||
)
|
||||
def test_moe_fused_gate_combined(seq_length, dtype, params):
|
||||
num_experts, num_expert_group, topk_group, topk = params
|
||||
|
||||
torch.manual_seed(seq_length)
|
||||
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
|
||||
scores = tensor.clone()
|
||||
bias = torch.rand(num_experts).to(dtype).cuda()
|
||||
|
||||
output, indices = moe_fused_gate(
|
||||
tensor,
|
||||
bias,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
topk=topk,
|
||||
)
|
||||
ref_output, ref_indices = biased_grouped_topk(
|
||||
scores,
|
||||
scores,
|
||||
bias,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
compiled=False,
|
||||
)
|
||||
|
||||
idx_check = torch.allclose(
|
||||
ref_indices.sort()[0].to(torch.int32),
|
||||
indices.sort()[0].to(torch.int32),
|
||||
rtol=1e-04,
|
||||
atol=1e-05,
|
||||
)
|
||||
output_check = torch.allclose(
|
||||
ref_output.sort()[0].to(torch.float32),
|
||||
output.sort()[0].to(torch.float32),
|
||||
rtol=1e-04,
|
||||
atol=1e-05,
|
||||
)
|
||||
|
||||
assert idx_check, (
|
||||
f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
|
||||
f"params {params}"
|
||||
)
|
||||
assert output_check, (
|
||||
f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
|
||||
f"params {params}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user