update 3rdparty and rms norm for sgl-kernel (#3213)
This commit is contained in:
2
sgl-kernel/3rdparty/cutlass
vendored
2
sgl-kernel/3rdparty/cutlass
vendored
Submodule sgl-kernel/3rdparty/cutlass updated: b78588d163...bdd641790a
2
sgl-kernel/3rdparty/flashinfer
vendored
2
sgl-kernel/3rdparty/flashinfer
vendored
Submodule sgl-kernel/3rdparty/flashinfer updated: 4f1f08989c...e5a3befbe3
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sgl-kernel"
|
name = "sgl-kernel"
|
||||||
version = "0.0.3"
|
version = "0.0.3.post1"
|
||||||
description = "Kernel Library for SGLang"
|
description = "Kernel Library for SGLang"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -1,116 +1,11 @@
|
|||||||
// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
|
|
||||||
// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
|
|
||||||
// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0
|
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
#include <flashinfer/math.cuh>
|
#include <flashinfer/norm.cuh>
|
||||||
#include <flashinfer/utils.cuh>
|
|
||||||
#include <flashinfer/vec_dtypes.cuh>
|
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
using namespace flashinfer;
|
using namespace flashinfer;
|
||||||
|
|
||||||
template <uint32_t VEC_SIZE, typename T>
|
|
||||||
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight,
|
|
||||||
const uint32_t d, float eps) {
|
|
||||||
const uint32_t bx = blockIdx.x;
|
|
||||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
|
||||||
constexpr uint32_t warp_size = 32;
|
|
||||||
const uint32_t num_warps = blockDim.y;
|
|
||||||
const uint32_t thread_id = tx + ty * warp_size;
|
|
||||||
const uint32_t num_threads = num_warps * warp_size;
|
|
||||||
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
|
|
||||||
extern __shared__ float smem[];
|
|
||||||
|
|
||||||
float sum_sq = 0.f;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < rounds; i++) {
|
|
||||||
vec_t<T, VEC_SIZE> input_vec;
|
|
||||||
input_vec.fill(0.f);
|
|
||||||
vec_t<T, VEC_SIZE> residual_vec;
|
|
||||||
residual_vec.fill(0.f);
|
|
||||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
|
||||||
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
|
||||||
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t j = 0; j < VEC_SIZE; j++) {
|
|
||||||
float x = float(input_vec[j]);
|
|
||||||
x += float(residual_vec[j]);
|
|
||||||
sum_sq += x * x;
|
|
||||||
residual_vec[j] = (T)x;
|
|
||||||
}
|
|
||||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
|
||||||
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// first, warp reduce sum
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
|
|
||||||
sum_sq += math::shfl_xor_sync(sum_sq, offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
smem[ty] = sum_sq;
|
|
||||||
__syncthreads();
|
|
||||||
// then, cross warp reduce sum using only the first warp
|
|
||||||
if (ty == 0) {
|
|
||||||
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
|
|
||||||
sum_sq += math::shfl_xor_sync(sum_sq, offset);
|
|
||||||
}
|
|
||||||
smem[0] = sum_sq;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < rounds; i++) {
|
|
||||||
vec_t<T, VEC_SIZE> input_vec;
|
|
||||||
vec_t<T, VEC_SIZE> weight_vec;
|
|
||||||
vec_t<T, VEC_SIZE> residual_vec;
|
|
||||||
input_vec.fill(0.f);
|
|
||||||
weight_vec.fill(0.f);
|
|
||||||
residual_vec.fill(0.f);
|
|
||||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
|
||||||
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
|
||||||
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
|
||||||
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t j = 0; j < VEC_SIZE; j++) {
|
|
||||||
input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
|
|
||||||
}
|
|
||||||
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
|
|
||||||
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5,
|
|
||||||
cudaStream_t stream = 0) {
|
|
||||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
|
||||||
|
|
||||||
const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
|
|
||||||
const uint32_t num_warps = ceil_div(block_size, 32);
|
|
||||||
dim3 nblks(batch_size);
|
|
||||||
dim3 nthrs(32, num_warps);
|
|
||||||
const uint32_t smem_size = num_warps * sizeof(float);
|
|
||||||
void* args[] = {&input, &residual, &weight, &d, &eps};
|
|
||||||
|
|
||||||
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
|
|
||||||
auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
|
|
||||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
|
||||||
});
|
|
||||||
|
|
||||||
return cudaSuccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
|
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
|
||||||
CHECK_INPUT(input);
|
CHECK_INPUT(input);
|
||||||
CHECK_INPUT(residual);
|
CHECK_INPUT(residual);
|
||||||
@@ -130,9 +25,9 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
|
|||||||
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
|
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
|
||||||
// support float16, bfloat16 and float32
|
// support float16, bfloat16 and float32
|
||||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
|
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
|
||||||
cudaError_t status =
|
cudaError_t status = norm::FusedAddRMSNorm(
|
||||||
FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
|
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
|
||||||
static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
|
static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
|
||||||
TORCH_CHECK(status == cudaSuccess,
|
TORCH_CHECK(status == cudaSuccess,
|
||||||
"FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
"FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
||||||
return true;
|
return true;
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "0.0.3"
|
__version__ = "0.0.3.post1"
|
||||||
|
|||||||
Reference in New Issue
Block a user