cleanup sgl-kernel kernels (#3175)
This commit is contained in:
1
.github/workflows/pr-test.yml
vendored
1
.github/workflows/pr-test.yml
vendored
@@ -51,6 +51,7 @@ jobs:
|
||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false
|
||||
runs-on: 1-gpu-runner
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
|
||||
steps:
|
||||
|
||||
@@ -88,7 +88,7 @@ sources = [
|
||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||
"3rdparty/flashinfer/csrc/norm.cu",
|
||||
|
||||
@@ -17,7 +17,6 @@ from sgl_kernel.ops import (
|
||||
moe_align_block_size,
|
||||
register_graph_buffers,
|
||||
rmsnorm,
|
||||
rotary_embedding,
|
||||
sampling_scaling_penalties,
|
||||
silu_and_mul,
|
||||
top_k_renorm_prob,
|
||||
@@ -44,7 +43,6 @@ __all__ = [
|
||||
"moe_align_block_size",
|
||||
"register_graph_buffers",
|
||||
"rmsnorm",
|
||||
"rotary_embedding",
|
||||
"sampling_scaling_penalties",
|
||||
"silu_and_mul",
|
||||
"top_k_renorm_prob",
|
||||
|
||||
140
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
Normal file
140
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
Normal file
@@ -0,0 +1,140 @@
|
||||
// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
|
||||
// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
|
||||
// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <flashinfer/math.cuh>
|
||||
#include <flashinfer/utils.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
#include <numeric>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
template <uint32_t VEC_SIZE, typename T>
|
||||
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight,
|
||||
const uint32_t d, float eps) {
|
||||
const uint32_t bx = blockIdx.x;
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
constexpr uint32_t warp_size = 32;
|
||||
const uint32_t num_warps = blockDim.y;
|
||||
const uint32_t thread_id = tx + ty * warp_size;
|
||||
const uint32_t num_threads = num_warps * warp_size;
|
||||
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
|
||||
extern __shared__ float smem[];
|
||||
|
||||
float sum_sq = 0.f;
|
||||
|
||||
for (uint32_t i = 0; i < rounds; i++) {
|
||||
vec_t<T, VEC_SIZE> input_vec;
|
||||
input_vec.fill(0.f);
|
||||
vec_t<T, VEC_SIZE> residual_vec;
|
||||
residual_vec.fill(0.f);
|
||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
||||
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
||||
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; j++) {
|
||||
float x = float(input_vec[j]);
|
||||
x += float(residual_vec[j]);
|
||||
sum_sq += x * x;
|
||||
residual_vec[j] = (T)x;
|
||||
}
|
||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
||||
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
// first, warp reduce sum
|
||||
#pragma unroll
|
||||
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
|
||||
sum_sq += math::shfl_xor_sync(sum_sq, offset);
|
||||
}
|
||||
|
||||
smem[ty] = sum_sq;
|
||||
__syncthreads();
|
||||
// then, cross warp reduce sum using only the first warp
|
||||
if (ty == 0) {
|
||||
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
|
||||
#pragma unroll
|
||||
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
|
||||
sum_sq += math::shfl_xor_sync(sum_sq, offset);
|
||||
}
|
||||
smem[0] = sum_sq;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
|
||||
|
||||
for (uint32_t i = 0; i < rounds; i++) {
|
||||
vec_t<T, VEC_SIZE> input_vec;
|
||||
vec_t<T, VEC_SIZE> weight_vec;
|
||||
vec_t<T, VEC_SIZE> residual_vec;
|
||||
input_vec.fill(0.f);
|
||||
weight_vec.fill(0.f);
|
||||
residual_vec.fill(0.f);
|
||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
||||
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
||||
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
||||
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; j++) {
|
||||
input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
|
||||
}
|
||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
||||
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5,
|
||||
cudaStream_t stream = 0) {
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||
|
||||
const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
|
||||
const uint32_t num_warps = ceil_div(block_size, 32);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(32, num_warps);
|
||||
const uint32_t smem_size = num_warps * sizeof(float);
|
||||
void* args[] = {&input, &residual, &weight, &d, &eps};
|
||||
|
||||
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
|
||||
auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
});
|
||||
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(residual);
|
||||
CHECK_INPUT(weight);
|
||||
auto device = input.device();
|
||||
CHECK_EQ(residual.device(), device);
|
||||
CHECK_EQ(weight.device(), device);
|
||||
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
|
||||
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
|
||||
CHECK_DIM(1, weight); // weight: (hidden_size)
|
||||
CHECK_EQ(input.size(0), residual.size(0));
|
||||
CHECK_EQ(input.size(1), residual.size(1));
|
||||
CHECK_EQ(input.size(1), weight.size(0));
|
||||
unsigned int batch_size = input.size(0);
|
||||
unsigned int hidden_size = input.size(1);
|
||||
|
||||
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
|
||||
// support float16, bfloat16 and float32
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
|
||||
cudaError_t status =
|
||||
FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
|
||||
static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
|
||||
TORCH_CHECK(status == cudaSuccess,
|
||||
"FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
return true;
|
||||
});
|
||||
}
|
||||
@@ -50,15 +50,11 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
|
||||
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
|
||||
// rotary embedding
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
// rms norm
|
||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
|
||||
// fused rms norm
|
||||
void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
|
||||
|
||||
// gemma rms norm
|
||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
|
||||
@@ -142,12 +142,6 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
)
|
||||
|
||||
|
||||
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
|
||||
return torch.ops.sgl_kernels.rotary_embedding(
|
||||
positions, query, key, head_size, cos_sin_cache, is_neox
|
||||
)
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||
# Kudos to @yzh119
|
||||
def rmsnorm(
|
||||
@@ -167,9 +161,7 @@ def fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
with input.device as device:
|
||||
torch.ops.sgl_kernels.fused_add_rmsnorm(
|
||||
input, residual, weight, eps, _get_cuda_stream(device)
|
||||
)
|
||||
torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
|
||||
|
||||
|
||||
def gemma_rmsnorm(
|
||||
|
||||
@@ -45,19 +45,13 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
"new_kv) -> ()");
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
|
||||
// rotary embedding
|
||||
m.def(
|
||||
"rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool "
|
||||
"is_neox) -> ()");
|
||||
m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||
|
||||
// rms norm
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
// fused rms norm
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm);
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
// gemma rms norm
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
||||
eps = 1e-6
|
||||
|
||||
|
||||
Reference in New Issue
Block a user