[EP] Add cuda kernel for moe_ep_post_reorder (#6837)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
92
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
Normal file
92
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import ep_moe_post_reorder
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
|
||||
|
||||
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
|
||||
configs = [(bs,) for bs in batch_sizes]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["cuda", "triton"],
|
||||
line_names=["CUDA Kernel", "Triton Kernel"],
|
||||
styles=[("green", "-"), ("orange", "-")],
|
||||
ylabel="us",
|
||||
plot_name="ep-moe-post-reorder-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512
|
||||
|
||||
def alloc_tensors():
|
||||
down_output = torch.randn(
|
||||
batch_size * topk, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
output = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
src2dst = torch.randint(
|
||||
0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
|
||||
)
|
||||
topk_ids = torch.randint(
|
||||
start_expert_id,
|
||||
end_expert_id + 1,
|
||||
(batch_size, topk),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
|
||||
return down_output, output, src2dst, topk_ids, topk_weights
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "cuda":
|
||||
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
|
||||
|
||||
def run_cuda():
|
||||
ep_moe_post_reorder(
|
||||
d_out,
|
||||
out,
|
||||
s2d,
|
||||
tk_ids,
|
||||
tk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
|
||||
|
||||
elif provider == "triton":
|
||||
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
|
||||
|
||||
def run_triton():
|
||||
post_reorder_triton_kernel[(batch_size,)](
|
||||
d_out.view(-1),
|
||||
out.view(-1),
|
||||
s2d.view(-1),
|
||||
tk_ids.view(-1),
|
||||
tk_weights.view(-1),
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
block_size,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
@@ -174,9 +174,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"(Tensor[])");
|
||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||
m.def(
|
||||
"ep_moe_pre_reorder(Tensor input_ptr, Tensor gateup_input_ptr, Tensor src2dst_ptr, Tensor topk_ids_ptr, Tensor "
|
||||
"a1_scales_ptr, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
|
||||
"ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
|
||||
"a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
|
||||
m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
|
||||
m.def(
|
||||
"ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
|
||||
"topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
|
||||
m.impl("ep_moe_post_reorder", torch::kCUDA, &ep_moe_post_reorder);
|
||||
m.def(
|
||||
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
|
||||
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
||||
|
||||
@@ -67,6 +67,57 @@ __global__ void ep_pre_reorder_cuda_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ep_post_reorder_cuda_kernel(
|
||||
const scalar_t* __restrict__ down_output_ptr,
|
||||
scalar_t* __restrict__ output_ptr,
|
||||
const int* __restrict__ src2dst_ptr,
|
||||
const int* __restrict__ topk_ids_ptr,
|
||||
const scalar_t* __restrict__ topk_weights_ptr,
|
||||
int start_expert_id,
|
||||
int end_expert_id,
|
||||
int topk,
|
||||
int hidden_size) {
|
||||
const int token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int* token_src2dst = src2dst_ptr + token_idx * topk;
|
||||
const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
|
||||
const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk;
|
||||
|
||||
scalar_t* dst_ptr = output_ptr + static_cast<int64_t>(token_idx) * hidden_size;
|
||||
|
||||
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
|
||||
|
||||
const int vec_iters = hidden_size / vec_size;
|
||||
for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
|
||||
float acc[vec_size] = {0};
|
||||
|
||||
for (int k = 0; k < topk; ++k) {
|
||||
const int expert_id = token_topk_ids[k];
|
||||
if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
|
||||
const int src_row = token_src2dst[k];
|
||||
const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
|
||||
const float weight = static_cast<float>(token_topk_weights[k]);
|
||||
|
||||
vec_t src_vec;
|
||||
src_vec.cast_load(src_ptr + idx * vec_size);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
acc[i] += static_cast<float>(src_vec[i]) * weight;
|
||||
}
|
||||
}
|
||||
vec_t out_vec;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < vec_size; ++i)
|
||||
out_vec[i] = static_cast<scalar_t>(acc[i]);
|
||||
|
||||
out_vec.cast_store(dst_ptr + idx * vec_size);
|
||||
}
|
||||
}
|
||||
|
||||
void ep_moe_pre_reorder(
|
||||
torch::Tensor input,
|
||||
torch::Tensor gateup_input,
|
||||
@@ -77,8 +128,8 @@ void ep_moe_pre_reorder(
|
||||
int64_t end_expert_id,
|
||||
int64_t topk,
|
||||
bool use_per_token_if_dynamic) {
|
||||
int total_blocks = input.size(0);
|
||||
int block_size = 512;
|
||||
const int total_blocks = input.size(0);
|
||||
const int block_size = 512;
|
||||
dim3 grid(total_blocks);
|
||||
dim3 block(block_size);
|
||||
int hidden_size = input.size(1);
|
||||
@@ -98,3 +149,33 @@ void ep_moe_pre_reorder(
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
void ep_moe_post_reorder(
|
||||
torch::Tensor down_output,
|
||||
torch::Tensor output,
|
||||
torch::Tensor src2dst,
|
||||
torch::Tensor topk_ids,
|
||||
torch::Tensor topk_weights,
|
||||
int64_t start_expert_id,
|
||||
int64_t end_expert_id,
|
||||
int64_t topk) {
|
||||
const int total_tokens = output.size(0);
|
||||
const int block_size = 512;
|
||||
dim3 grid(total_tokens);
|
||||
dim3 block(block_size);
|
||||
const int hidden_size = output.size(1);
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] {
|
||||
ep_post_reorder_cuda_kernel<scalar_t><<<grid, block>>>(
|
||||
static_cast<scalar_t*>(down_output.data_ptr()),
|
||||
static_cast<scalar_t*>(output.data_ptr()),
|
||||
src2dst.data_ptr<int>(),
|
||||
topk_ids.data_ptr<int>(),
|
||||
static_cast<scalar_t*>(topk_weights.data_ptr()),
|
||||
static_cast<int>(start_expert_id),
|
||||
static_cast<int>(end_expert_id),
|
||||
static_cast<int>(topk),
|
||||
hidden_size);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -264,6 +264,16 @@ void ep_moe_pre_reorder(
|
||||
int64_t topk,
|
||||
bool use_per_token_if_dynamic);
|
||||
|
||||
void ep_moe_post_reorder(
|
||||
torch::Tensor down_output,
|
||||
torch::Tensor output,
|
||||
torch::Tensor src2dst,
|
||||
torch::Tensor topk_ids,
|
||||
torch::Tensor topk_weights,
|
||||
int64_t start_expert_id,
|
||||
int64_t end_expert_id,
|
||||
int64_t topk);
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
|
||||
@@ -49,6 +49,7 @@ from sgl_kernel.gemm import (
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.moe import (
|
||||
cutlass_fp4_group_mm,
|
||||
ep_moe_post_reorder,
|
||||
ep_moe_pre_reorder,
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
moe_align_block_size,
|
||||
|
||||
@@ -88,6 +88,28 @@ def ep_moe_pre_reorder(
|
||||
)
|
||||
|
||||
|
||||
def ep_moe_post_reorder(
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
):
|
||||
return torch.ops.sgl_kernel.ep_moe_post_reorder.default(
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
)
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_grouped_mm(
|
||||
output,
|
||||
a_ptrs,
|
||||
|
||||
163
sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
Normal file
163
sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import ep_moe_post_reorder
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
|
||||
|
||||
|
||||
def create_test_tensors(
|
||||
batch_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
start_expert_id: int,
|
||||
end_expert_id: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
down_output = torch.randn(
|
||||
batch_size * topk, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Ensure src2dst has no duplicate destinations to avoid race conditions
|
||||
total_tokens = batch_size * topk
|
||||
dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32)
|
||||
src2dst = dst_indices.view(batch_size, topk)
|
||||
|
||||
topk_ids = torch.randint(
|
||||
start_expert_id,
|
||||
end_expert_id + 1,
|
||||
(batch_size, topk),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
|
||||
|
||||
return down_output, src2dst, topk_ids, topk_weights
|
||||
|
||||
|
||||
def run_cuda_kernel(
|
||||
down_output: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
src2dst: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
start_expert_id: int,
|
||||
end_expert_id: int,
|
||||
topk: int,
|
||||
):
|
||||
ep_moe_post_reorder(
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def run_triton_kernel(
|
||||
down_output: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
src2dst: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
start_expert_id: int,
|
||||
end_expert_id: int,
|
||||
topk: int,
|
||||
hidden_size: int,
|
||||
):
|
||||
batch_size = down_output.size(0)
|
||||
block_size = 512
|
||||
|
||||
post_reorder_triton_kernel[(batch_size,)](
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
block_size,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def assert_close(a, b):
|
||||
a32, b32 = a.float(), b.float()
|
||||
if a.dtype is torch.float16:
|
||||
torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-2)
|
||||
elif a.dtype is torch.bfloat16:
|
||||
torch.testing.assert_close(a32, b32, rtol=1e-4, atol=1e-1)
|
||||
else:
|
||||
torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,hidden_size,topk",
|
||||
list(itertools.product([32, 64], [128, 256, 512], [2, 4, 8])),
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_ep_moe_post_reorder_vs_triton(
|
||||
batch_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
start_expert_id = 0
|
||||
end_expert_id = 15
|
||||
|
||||
(
|
||||
down_output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
) = create_test_tensors(
|
||||
batch_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
output_cuda = torch.empty(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
output_triton = torch.empty(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
cuda_output = run_cuda_kernel(
|
||||
down_output,
|
||||
output_cuda,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
)
|
||||
|
||||
triton_output = run_triton_kernel(
|
||||
down_output,
|
||||
output_triton,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
)
|
||||
|
||||
assert_close(cuda_output, triton_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user