diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index cf3c6a563..56c5b1bb5 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -38,10 +38,10 @@ def _get_version():
                 return line.split("=")[1].strip().strip('"')
 
 
-cutlass = root / "3rdparty" / "cutlass"
 cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
+turbomind = root / "3rdparty" / "turbomind"
 include_dirs = [
     cutlass.resolve() / "include",
     cutlass.resolve() / "tools" / "util" / "include",
@@ -49,6 +49,8 @@ include_dirs = [
     flashinfer.resolve() / "include",
     flashinfer.resolve() / "include" / "gemm",
     flashinfer.resolve() / "csrc",
+    turbomind.resolve(),
+    turbomind.resolve() / "src",
 ]
 nvcc_flags = [
     "-DNDEBUG",
@@ -63,6 +65,11 @@ nvcc_flags = [
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
 ]
+nvcc_flags_fp8 = [
+    "-DFLASHINFER_ENABLE_FP8",
+    "-DFLASHINFER_ENABLE_FP8_E4M3",
+    "-DFLASHINFER_ENABLE_FP8_E5M2",
+]
 
 sources = [
     "src/sgl-kernel/csrc/trt_reduce_internal.cu",
@@ -73,6 +80,7 @@ sources = [
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
     "src/sgl-kernel/csrc/rotary_embedding.cu",
+    "src/sgl-kernel/csrc/fused_add_rms_norm.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/group_gemm.cu",
@@ -92,13 +100,7 @@ if torch.cuda.is_available():
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
         sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if sm_version >= 90:
-        nvcc_flags.extend(
-            [
-                "-DFLASHINFER_ENABLE_FP8",
-                "-DFLASHINFER_ENABLE_FP8_E4M3",
-                "-DFLASHINFER_ENABLE_FP8_E5M2",
-            ]
-        )
+        nvcc_flags.extend(nvcc_flags_fp8)
     if sm_version >= 80:
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 else:
@@ -107,13 +109,7 @@ else:
     nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
     sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if enable_fp8:
-        nvcc_flags.extend(
-            [
-                "-DFLASHINFER_ENABLE_FP8",
-                "-DFLASHINFER_ENABLE_FP8_E4M3",
-                "-DFLASHINFER_ENABLE_FP8_E5M2",
-            ]
-        )
+        nvcc_flags.extend(nvcc_flags_fp8)
     if enable_bf16:
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 
diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu
new file mode 100644
index 000000000..734061586
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu
@@ -0,0 +1,92 @@
+// Adapted from
+// https://github.com/InternLM/lmdeploy/blob/800b6010c0bf76aadf678bc38a507b749fb9774c/src/turbomind/kernels/norm/rms_norm.cu
+
+#include <src/turbomind/kernels/core/array_ops.h>
+#include <src/turbomind/kernels/core/common.h>
+
+#include <cub/block/block_reduce.cuh>
+
+using namespace turbomind;
+
+template <class T, int block_dim, int vec_size>
+__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual, T* __restrict__ hidden_states,
+                                          const T* __restrict__ weights, const T* __restrict__ bias, int dims, int num,
+                                          float eps, float inv_dims) {
+  const int ti = blockIdx.x;
+  const int di = threadIdx.x * vec_size;
+
+  if (ti >= num) {
+    return;
+  }
+
+  residual += dims * ti;
+  hidden_states += dims * ti;
+
+  Array<float, vec_size> accum{};
+
+  Array<T, vec_size> r_vec;
+  Array<T, vec_size> h_vec;
+  Array<T, vec_size> b_vec;
+
+  for (int i = di; i < dims; i += block_dim * vec_size) {
+    Load(r_vec, &residual[i]);
+    Load(h_vec, &hidden_states[i]);
+
+    using namespace ops;
+    r_vec = r_vec + h_vec;
+
+    if (bias) {
+      Ldg(b_vec, &bias[i]);
+      r_vec = r_vec + b_vec;
+    }
+
+    Store(&residual[i], r_vec);
+
+    Array<float, vec_size> tmp = cast<float>(r_vec);
+
+    accum = accum + tmp * tmp;
+  }
+
+  float sum{};
+  PRAGMA_UNROLL
+  for (int i = 0; i < vec_size; ++i) {
+    sum += accum[i];
+  }
+
+  using BlockReduce = cub::BlockReduce<float, block_dim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  sum = BlockReduce{temp_storage}.Sum(sum);
+
+  __shared__ float shared_sum;
+
+  if (threadIdx.x == 0) {
+    shared_sum = rsqrtf(sum * inv_dims + eps);
+  }
+
+  __syncthreads();
+
+  sum = shared_sum;
+
+  Array<T, vec_size> w_vec;
+  for (int i = di; i < dims; i += block_dim * vec_size) {
+    Load(r_vec, &residual[i]);
+    Ldg(w_vec, &weights[i]);
+    PRAGMA_UNROLL
+    for (int c = 0; c < vec_size; ++c) {
+      r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];
+    }
+    Store(&hidden_states[i], r_vec);
+  }
+}
+
+template <class T>
+void invokeBiasResidualRMSNorm(T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num,
+                               float eps, cudaStream_t st) {
+  constexpr int vec_size = 16 / sizeof(T);
+  constexpr int threads = 512;
+  const int blocks = num;
+
+  BiasResidualRMSNormKernel<T, threads, vec_size>
+      <<<blocks, threads, 0, st>>>(residual, hidden_states, weights, bias, dims, num, eps, 1.f / dims);
+}