[sgl-kernel] Add cuda kernel for moe_ep_silu_and_mul (#6919)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -237,6 +237,7 @@ set(SOURCES
|
|||||||
"csrc/moe/fp8_blockwise_moe_kernel.cu"
|
"csrc/moe/fp8_blockwise_moe_kernel.cu"
|
||||||
"csrc/moe/prepare_moe_input.cu"
|
"csrc/moe/prepare_moe_input.cu"
|
||||||
"csrc/moe/ep_moe_reorder_kernel.cu"
|
"csrc/moe/ep_moe_reorder_kernel.cu"
|
||||||
|
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
|
||||||
"csrc/speculative/eagle_utils.cu"
|
"csrc/speculative/eagle_utils.cu"
|
||||||
"csrc/speculative/speculative_sampling.cu"
|
"csrc/speculative/speculative_sampling.cu"
|
||||||
"csrc/speculative/packbit.cu"
|
"csrc/speculative/packbit.cu"
|
||||||
|
|||||||
92
sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Normal file
92
sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from sgl_kernel import ep_moe_silu_and_mul
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel
|
||||||
|
|
||||||
|
batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
|
||||||
|
hidden_size_range = [1024, 2048, 4096, 8192]
|
||||||
|
block_size_range = [128, 256, 512]
|
||||||
|
configs = list(itertools.product(batch_size_range, hidden_size_range, block_size_range))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "hidden_size", "block_size"],
|
||||||
|
x_vals=[list(cfg) for cfg in configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["cuda", "triton"],
|
||||||
|
line_names=["CUDA Kernel", "Triton Kernel"],
|
||||||
|
styles=[("green", "-"), ("orange", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name="ep-moe-silu-and-mul-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, hidden_size, block_size, provider):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
half_hidden_size = hidden_size // 2
|
||||||
|
start_expert_id, end_expert_id = 0, 255
|
||||||
|
block_size = 512
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
def alloc_tensors():
|
||||||
|
gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
down_input = torch.empty(
|
||||||
|
batch_size, half_hidden_size, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
reorder_topk_ids = torch.randint(
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id + 1,
|
||||||
|
(batch_size,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
scales = torch.rand(
|
||||||
|
end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
return gateup_output, down_input, reorder_topk_ids, scales
|
||||||
|
|
||||||
|
if provider == "cuda":
|
||||||
|
gateup, down, ids, scales = alloc_tensors()
|
||||||
|
|
||||||
|
def run_cuda():
|
||||||
|
ep_moe_silu_and_mul(
|
||||||
|
gateup,
|
||||||
|
down,
|
||||||
|
ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
|
||||||
|
|
||||||
|
elif provider == "triton":
|
||||||
|
gateup, down, ids, scales = alloc_tensors()
|
||||||
|
|
||||||
|
def run_triton():
|
||||||
|
silu_and_mul_triton_kernel[(batch_size,)](
|
||||||
|
gateup.view(-1),
|
||||||
|
down.view(-1),
|
||||||
|
hidden_size,
|
||||||
|
ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark.run(print_data=True)
|
||||||
@@ -177,6 +177,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
|
"ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
|
||||||
"a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
|
"a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
|
||||||
m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
|
m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
|
||||||
|
m.def(
|
||||||
|
"ep_moe_silu_and_mul(Tensor gateup_output, Tensor down_input, Tensor reorder_topk_ids, Tensor scales, int "
|
||||||
|
"start_expert_id, int end_expert_id) -> ()");
|
||||||
|
m.impl("ep_moe_silu_and_mul", torch::kCUDA, &ep_moe_silu_and_mul);
|
||||||
m.def(
|
m.def(
|
||||||
"ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
|
"ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
|
||||||
"topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
|
"topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
|
||||||
|
|||||||
115
sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu
Normal file
115
sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include <THC/THCAtomics.cuh>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <flashinfer/vec_dtypes.cuh>
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
using namespace flashinfer;
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
__device__ inline scalar_t silu_quantize(float x);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline float silu_quantize<float>(float x) {
|
||||||
|
float y = x / (1.f + __expf(-x));
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline __half silu_quantize<__half>(float x) {
|
||||||
|
float y = x / (1.f + __expf(-x));
|
||||||
|
return __float2half_rn(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline __nv_bfloat16 silu_quantize<__nv_bfloat16>(float x) {
|
||||||
|
float y = x / (1.f + __expf(-x));
|
||||||
|
return __float2bfloat16_rn(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
__global__ void ep_moe_act_and_mul_cuda_kernel(
|
||||||
|
const scalar_t* __restrict__ gateup_output,
|
||||||
|
scalar_t* __restrict__ down_input,
|
||||||
|
const int* __restrict__ reorder_topk_ids,
|
||||||
|
const float* __restrict__ scales,
|
||||||
|
int start_expert_id,
|
||||||
|
int end_expert_id,
|
||||||
|
int hidden_size) {
|
||||||
|
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||||
|
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
|
||||||
|
|
||||||
|
const int64_t token_idx = blockIdx.x;
|
||||||
|
const int64_t thread_idx = threadIdx.x;
|
||||||
|
const int64_t stride = blockDim.x;
|
||||||
|
|
||||||
|
const int half_hidden_size = hidden_size >> 1;
|
||||||
|
const int expert_id = reorder_topk_ids[token_idx];
|
||||||
|
|
||||||
|
if (expert_id < start_expert_id || expert_id > end_expert_id) return;
|
||||||
|
const scalar_t* gate_output_ptr = gateup_output + static_cast<int64_t>(token_idx) * hidden_size;
|
||||||
|
const scalar_t* up_output_ptr = gate_output_ptr + half_hidden_size;
|
||||||
|
scalar_t* dst_ptr = down_input + static_cast<int64_t>(token_idx) * half_hidden_size;
|
||||||
|
scalar_t scale_q = static_cast<scalar_t>(scales ? (1.f / scales[expert_id - start_expert_id]) : 1.f);
|
||||||
|
|
||||||
|
const uint32_t vec_elements = half_hidden_size / vec_size;
|
||||||
|
#pragma unroll 1
|
||||||
|
for (uint32_t idx = thread_idx; idx < vec_elements; idx += stride) {
|
||||||
|
vec_t gate_vec, up_vec, out_vec;
|
||||||
|
gate_vec.load(gate_output_ptr + idx * vec_size);
|
||||||
|
up_vec.load(up_output_ptr + idx * vec_size);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||||
|
float gate_f = static_cast<float>(gate_vec[i]);
|
||||||
|
scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
|
||||||
|
scalar_t prod = gate_q * up_vec[i] * scale_q;
|
||||||
|
out_vec[i] = prod;
|
||||||
|
}
|
||||||
|
out_vec.store(dst_ptr + idx * vec_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t scalar_start = static_cast<int64_t>(vec_elements) * vec_size + thread_idx;
|
||||||
|
#pragma unroll 1
|
||||||
|
for (int64_t idx = scalar_start; idx < half_hidden_size; idx += stride) {
|
||||||
|
float gate_f = static_cast<float>(gate_output_ptr[idx]);
|
||||||
|
scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
|
||||||
|
dst_ptr[idx] = gate_q * up_output_ptr[idx] * scale_q;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ep_moe_silu_and_mul(
|
||||||
|
torch::Tensor gateup_output,
|
||||||
|
torch::Tensor down_input,
|
||||||
|
torch::Tensor reorder_topk_ids,
|
||||||
|
torch::Tensor scales,
|
||||||
|
int64_t start_expert_id,
|
||||||
|
int64_t end_expert_id) {
|
||||||
|
const int total_tokens = gateup_output.size(0);
|
||||||
|
const int hidden_size = gateup_output.size(1);
|
||||||
|
|
||||||
|
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(gateup_output.scalar_type(), scalar_t, [&] {
|
||||||
|
dim3 grid(total_tokens);
|
||||||
|
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||||
|
const int half_hidden_size = hidden_size >> 1;
|
||||||
|
uint32_t threads = (half_hidden_size + vec_size - 1) / vec_size;
|
||||||
|
threads = std::max<uint32_t>(threads, 256);
|
||||||
|
threads = ((threads + 31) & ~31U);
|
||||||
|
dim3 block(std::min(threads, 1024U));
|
||||||
|
ep_moe_act_and_mul_cuda_kernel<scalar_t><<<grid, block>>>(
|
||||||
|
static_cast<scalar_t*>(gateup_output.data_ptr()),
|
||||||
|
static_cast<scalar_t*>(down_input.data_ptr()),
|
||||||
|
reorder_topk_ids.data_ptr<int>(),
|
||||||
|
scales.defined() ? scales.data_ptr<float>() : nullptr,
|
||||||
|
static_cast<int>(start_expert_id),
|
||||||
|
static_cast<int>(end_expert_id),
|
||||||
|
hidden_size);
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -266,6 +266,14 @@ void ep_moe_pre_reorder(
|
|||||||
int64_t topk,
|
int64_t topk,
|
||||||
bool use_per_token_if_dynamic);
|
bool use_per_token_if_dynamic);
|
||||||
|
|
||||||
|
void ep_moe_silu_and_mul(
|
||||||
|
torch::Tensor gateup_output,
|
||||||
|
torch::Tensor down_input,
|
||||||
|
torch::Tensor reorder_topk_ids,
|
||||||
|
torch::Tensor scales,
|
||||||
|
int64_t start_expert_id,
|
||||||
|
int64_t end_expert_id);
|
||||||
|
|
||||||
void ep_moe_post_reorder(
|
void ep_moe_post_reorder(
|
||||||
torch::Tensor down_output,
|
torch::Tensor down_output,
|
||||||
torch::Tensor output,
|
torch::Tensor output,
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from sgl_kernel.moe import (
|
|||||||
cutlass_fp4_group_mm,
|
cutlass_fp4_group_mm,
|
||||||
ep_moe_post_reorder,
|
ep_moe_post_reorder,
|
||||||
ep_moe_pre_reorder,
|
ep_moe_pre_reorder,
|
||||||
|
ep_moe_silu_and_mul,
|
||||||
fp8_blockwise_scaled_grouped_mm,
|
fp8_blockwise_scaled_grouped_mm,
|
||||||
moe_align_block_size,
|
moe_align_block_size,
|
||||||
moe_fused_gate,
|
moe_fused_gate,
|
||||||
|
|||||||
@@ -88,6 +88,24 @@ def ep_moe_pre_reorder(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ep_moe_silu_and_mul(
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
):
|
||||||
|
return torch.ops.sgl_kernel.ep_moe_silu_and_mul.default(
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ep_moe_post_reorder(
|
def ep_moe_post_reorder(
|
||||||
down_output,
|
down_output,
|
||||||
output,
|
output,
|
||||||
|
|||||||
142
sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py
Normal file
142
sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import ep_moe_silu_and_mul
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_tensors(
|
||||||
|
total_tokens: int,
|
||||||
|
hidden_size: int,
|
||||||
|
start_expert_id: int,
|
||||||
|
end_expert_id: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
gateup_output = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
reorder_topk_ids = torch.randint(
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id + 1,
|
||||||
|
(total_tokens,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_experts = end_expert_id - start_expert_id + 1
|
||||||
|
scales = torch.rand(num_experts, dtype=torch.float32, device=device) * 0.8 + 0.5
|
||||||
|
|
||||||
|
half_hidden = hidden_size // 2
|
||||||
|
down_input = torch.empty(total_tokens, half_hidden, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return gateup_output, down_input, reorder_topk_ids, scales
|
||||||
|
|
||||||
|
|
||||||
|
def run_cuda_kernel(
|
||||||
|
gateup_output: torch.Tensor,
|
||||||
|
down_input: torch.Tensor,
|
||||||
|
reorder_topk_ids: torch.Tensor,
|
||||||
|
scales: torch.Tensor,
|
||||||
|
start_expert_id: int,
|
||||||
|
end_expert_id: int,
|
||||||
|
):
|
||||||
|
ep_moe_silu_and_mul(
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
)
|
||||||
|
return down_input
|
||||||
|
|
||||||
|
|
||||||
|
def run_triton_kernel(
|
||||||
|
gateup_output: torch.Tensor,
|
||||||
|
down_input: torch.Tensor,
|
||||||
|
reorder_topk_ids: torch.Tensor,
|
||||||
|
scales: torch.Tensor,
|
||||||
|
start_expert_id: int,
|
||||||
|
end_expert_id: int,
|
||||||
|
hidden_size: int,
|
||||||
|
):
|
||||||
|
total_tokens = gateup_output.size(0)
|
||||||
|
block_size = 512
|
||||||
|
|
||||||
|
silu_and_mul_triton_kernel[(total_tokens,)](
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
hidden_size,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
block_size,
|
||||||
|
)
|
||||||
|
return down_input
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"total_tokens,hidden_size",
|
||||||
|
list(itertools.product([32, 256, 1024], [128, 256, 512])),
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||||
|
def test_ep_moe_silu_and_mul_vs_triton(
|
||||||
|
total_tokens: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
start_expert_id = 0
|
||||||
|
end_expert_id = 15
|
||||||
|
|
||||||
|
(
|
||||||
|
gateup_output,
|
||||||
|
_,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
) = create_test_tensors(
|
||||||
|
total_tokens,
|
||||||
|
hidden_size,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_input_cuda = torch.empty(
|
||||||
|
total_tokens, hidden_size // 2, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
down_input_triton = torch.empty_like(down_input_cuda)
|
||||||
|
|
||||||
|
cuda_output = run_cuda_kernel(
|
||||||
|
gateup_output,
|
||||||
|
down_input_cuda,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
triton_output = run_triton_kernel(
|
||||||
|
gateup_output,
|
||||||
|
down_input_triton,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
hidden_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
cuda_output,
|
||||||
|
triton_output,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user