fix ep_moe_reorder kernel bugs (#6858)
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
This commit is contained in:
@@ -1,8 +1,5 @@
|
|||||||
import itertools
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
|
||||||
from sgl_kernel import ep_moe_pre_reorder
|
from sgl_kernel import ep_moe_pre_reorder
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
|
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
|
||||||
@@ -25,9 +22,15 @@ configs = [(bs,) for bs in batch_sizes]
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
def benchmark(batch_size, provider):
|
def benchmark(batch_size, provider):
|
||||||
dtype = torch.float32
|
dtype = torch.bfloat16
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512
|
hidden_size, topk, start_expert_id, end_expert_id, block_size = (
|
||||||
|
4096,
|
||||||
|
8,
|
||||||
|
0,
|
||||||
|
255,
|
||||||
|
512,
|
||||||
|
)
|
||||||
|
|
||||||
# Allocate fresh tensors for every run to match bench_moe_fused_gate style
|
# Allocate fresh tensors for every run to match bench_moe_fused_gate style
|
||||||
def alloc_tensors():
|
def alloc_tensors():
|
||||||
@@ -53,9 +56,9 @@ def benchmark(batch_size, provider):
|
|||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
if provider == "cuda":
|
if provider == "cuda":
|
||||||
|
inp, gout, s2d, tk_ids, scales = alloc_tensors()
|
||||||
|
|
||||||
def run_cuda():
|
def run_cuda():
|
||||||
inp, gout, s2d, tk_ids, scales = alloc_tensors()
|
|
||||||
ep_moe_pre_reorder(
|
ep_moe_pre_reorder(
|
||||||
inp,
|
inp,
|
||||||
gout,
|
gout,
|
||||||
@@ -71,9 +74,9 @@ def benchmark(batch_size, provider):
|
|||||||
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
|
||||||
|
|
||||||
elif provider == "triton":
|
elif provider == "triton":
|
||||||
|
inp, gout, s2d, tk_ids, scales = alloc_tensors()
|
||||||
|
|
||||||
def run_triton():
|
def run_triton():
|
||||||
inp, gout, s2d, tk_ids, scales = alloc_tensors()
|
|
||||||
pre_reorder_triton_kernel[(batch_size,)](
|
pre_reorder_triton_kernel[(batch_size,)](
|
||||||
inp.view(-1),
|
inp.view(-1),
|
||||||
gout.view(-1),
|
gout.view(-1),
|
||||||
|
|||||||
@@ -7,9 +7,10 @@
|
|||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
__global__ void ep_pre_reorder_cuda_kernel(
|
__global__ void ep_pre_reorder_cuda_kernel(
|
||||||
const float* __restrict__ input_ptr,
|
const scalar_t* __restrict__ input_ptr,
|
||||||
float* __restrict__ gateup_input_ptr,
|
scalar_t* __restrict__ gateup_input_ptr,
|
||||||
const int* __restrict__ src2dst_ptr,
|
const int* __restrict__ src2dst_ptr,
|
||||||
const int* __restrict__ topk_ids_ptr,
|
const int* __restrict__ topk_ids_ptr,
|
||||||
const float* __restrict__ a1_scales_ptr,
|
const float* __restrict__ a1_scales_ptr,
|
||||||
@@ -21,20 +22,20 @@ __global__ void ep_pre_reorder_cuda_kernel(
|
|||||||
int token_idx = blockIdx.x;
|
int token_idx = blockIdx.x;
|
||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
|
|
||||||
const float* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
|
const scalar_t* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
|
||||||
const int* token_src2dst = src2dst_ptr + token_idx * topk;
|
const int* token_src2dst = src2dst_ptr + token_idx * topk;
|
||||||
const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
|
const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
|
||||||
|
|
||||||
|
float scale = 1.0f;
|
||||||
|
|
||||||
|
if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) {
|
||||||
|
scale = 1.0f / a1_scales_ptr[token_idx];
|
||||||
|
}
|
||||||
|
|
||||||
for (int k = 0; k < topk; ++k) {
|
for (int k = 0; k < topk; ++k) {
|
||||||
int expert_id = token_topk_ids[k];
|
int expert_id = token_topk_ids[k];
|
||||||
if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
|
if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
|
||||||
|
|
||||||
float scale = 1.0f;
|
|
||||||
|
|
||||||
if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) {
|
|
||||||
scale = 1.0f / a1_scales_ptr[token_idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (a1_scales_ptr != nullptr) {
|
if (a1_scales_ptr != nullptr) {
|
||||||
if (!use_per_token_if_dynamic) {
|
if (!use_per_token_if_dynamic) {
|
||||||
scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id];
|
scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id];
|
||||||
@@ -42,21 +43,27 @@ __global__ void ep_pre_reorder_cuda_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
int dst_idx = token_src2dst[k];
|
int dst_idx = token_src2dst[k];
|
||||||
float* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;
|
scalar_t* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;
|
||||||
|
|
||||||
constexpr uint32_t vec_size = 16 / sizeof(float);
|
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||||
using vec_t = flashinfer::vec_t<float, vec_size>;
|
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
|
||||||
|
|
||||||
|
int vec_elements = (hidden_size / vec_size) * vec_size;
|
||||||
for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) {
|
for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) {
|
||||||
vec_t input_vec, output_vec;
|
vec_t input_vec, output_vec;
|
||||||
input_vec.cast_load(src_ptr + idx * vec_size);
|
input_vec.cast_load(src_ptr + idx * vec_size);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||||
float val = static_cast<float>(input_vec[i]);
|
float val = static_cast<float>(input_vec[i]);
|
||||||
output_vec[i] = val * scale;
|
output_vec[i] = static_cast<scalar_t>(val * scale);
|
||||||
}
|
}
|
||||||
output_vec.cast_store(dst_ptr + idx * vec_size);
|
output_vec.cast_store(dst_ptr + idx * vec_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int idx = vec_elements + tid; idx < hidden_size; idx += blockDim.x) {
|
||||||
|
float val = static_cast<float>(src_ptr[idx]);
|
||||||
|
dst_ptr[idx] = static_cast<scalar_t>(val * scale);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,15 +82,19 @@ void ep_moe_pre_reorder(
|
|||||||
dim3 grid(total_blocks);
|
dim3 grid(total_blocks);
|
||||||
dim3 block(block_size);
|
dim3 block(block_size);
|
||||||
int hidden_size = input.size(1);
|
int hidden_size = input.size(1);
|
||||||
ep_pre_reorder_cuda_kernel<<<grid, block>>>(
|
|
||||||
input.data_ptr<float>(),
|
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
|
||||||
gateup_input.data_ptr<float>(),
|
ep_pre_reorder_cuda_kernel<scalar_t><<<grid, block>>>(
|
||||||
src2dst.data_ptr<int>(),
|
static_cast<scalar_t*>(input.data_ptr()),
|
||||||
topk_ids.data_ptr<int>(),
|
static_cast<scalar_t*>(gateup_input.data_ptr()),
|
||||||
a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
|
src2dst.data_ptr<int>(),
|
||||||
start_expert_id,
|
topk_ids.data_ptr<int>(),
|
||||||
end_expert_id,
|
a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
|
||||||
topk,
|
start_expert_id,
|
||||||
hidden_size,
|
end_expert_id,
|
||||||
use_per_token_if_dynamic);
|
topk,
|
||||||
|
hidden_size,
|
||||||
|
use_per_token_if_dynamic);
|
||||||
|
return true;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
181
sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py
Normal file
181
sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import ep_moe_pre_reorder
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_tensors(
|
||||||
|
batch_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
start_expert_id: int,
|
||||||
|
end_expert_id: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
use_per_token_if_dynamic: bool = True,
|
||||||
|
):
|
||||||
|
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# Ensure src2dst has no duplicate destinations to avoid race conditions
|
||||||
|
total_tokens = batch_size * topk
|
||||||
|
dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32)
|
||||||
|
src2dst = dst_indices.view(batch_size, topk)
|
||||||
|
|
||||||
|
topk_ids = torch.randint(
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id + 1,
|
||||||
|
(batch_size, topk),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_per_token_if_dynamic:
|
||||||
|
a1_scales = (
|
||||||
|
torch.rand(batch_size, dtype=torch.float32, device=device) * 0.8 + 0.2
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
a1_scales = (
|
||||||
|
torch.rand(
|
||||||
|
end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
* 0.8
|
||||||
|
+ 0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
return input_tensor, src2dst, topk_ids, a1_scales
|
||||||
|
|
||||||
|
|
||||||
|
def run_cuda_kernel(
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
gateup_input: torch.Tensor,
|
||||||
|
src2dst: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
a1_scales: torch.Tensor,
|
||||||
|
start_expert_id: int,
|
||||||
|
end_expert_id: int,
|
||||||
|
topk: int,
|
||||||
|
use_per_token_if_dynamic: bool,
|
||||||
|
):
|
||||||
|
ep_moe_pre_reorder(
|
||||||
|
input_tensor,
|
||||||
|
gateup_input,
|
||||||
|
src2dst,
|
||||||
|
topk_ids,
|
||||||
|
a1_scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
topk,
|
||||||
|
use_per_token_if_dynamic,
|
||||||
|
)
|
||||||
|
return gateup_input
|
||||||
|
|
||||||
|
|
||||||
|
def run_triton_kernel(
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
gateup_input: torch.Tensor,
|
||||||
|
src2dst: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
a1_scales: torch.Tensor,
|
||||||
|
start_expert_id: int,
|
||||||
|
end_expert_id: int,
|
||||||
|
topk: int,
|
||||||
|
hidden_size: int,
|
||||||
|
use_per_token_if_dynamic: bool,
|
||||||
|
):
|
||||||
|
batch_size = input_tensor.size(0)
|
||||||
|
block_size = 512
|
||||||
|
|
||||||
|
pre_reorder_triton_kernel[(batch_size,)](
|
||||||
|
input_tensor,
|
||||||
|
gateup_input,
|
||||||
|
src2dst,
|
||||||
|
topk_ids,
|
||||||
|
a1_scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
topk,
|
||||||
|
hidden_size,
|
||||||
|
block_size,
|
||||||
|
use_per_token_if_dynamic,
|
||||||
|
)
|
||||||
|
return gateup_input
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"batch_size,hidden_size,topk",
|
||||||
|
list(itertools.product([32, 64, 128], [512, 1024, 2048], [4, 8])),
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||||
|
@pytest.mark.parametrize("use_per_token_if_dynamic", [True, False])
|
||||||
|
def test_ep_moe_pre_reorder_vs_triton(
|
||||||
|
batch_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_per_token_if_dynamic: bool,
|
||||||
|
):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
start_expert_id = 0
|
||||||
|
end_expert_id = 15
|
||||||
|
|
||||||
|
(
|
||||||
|
input_tensor,
|
||||||
|
src2dst,
|
||||||
|
topk_ids,
|
||||||
|
a1_scales,
|
||||||
|
) = create_test_tensors(
|
||||||
|
batch_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
use_per_token_if_dynamic,
|
||||||
|
)
|
||||||
|
|
||||||
|
gateup_input_cuda = torch.empty(
|
||||||
|
batch_size * topk, hidden_size, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
gateup_input_triton = torch.empty(
|
||||||
|
batch_size * topk, hidden_size, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_output = run_cuda_kernel(
|
||||||
|
input_tensor,
|
||||||
|
gateup_input_cuda,
|
||||||
|
src2dst,
|
||||||
|
topk_ids,
|
||||||
|
a1_scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
topk,
|
||||||
|
use_per_token_if_dynamic,
|
||||||
|
)
|
||||||
|
|
||||||
|
triton_output = run_triton_kernel(
|
||||||
|
input_tensor,
|
||||||
|
gateup_input_triton,
|
||||||
|
src2dst,
|
||||||
|
topk_ids,
|
||||||
|
a1_scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
topk,
|
||||||
|
hidden_size,
|
||||||
|
use_per_token_if_dynamic,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
cuda_output.float(),
|
||||||
|
triton_output.float(),
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user