diff --git a/.gitignore b/.gitignore index 4427e812a..6470c7df5 100644 --- a/.gitignore +++ b/.gitignore @@ -185,3 +185,36 @@ work_dirs/ *.csv !logo.png + +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 41f27d07d..4330d2e19 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,37 +1,34 @@ [build-system] -requires = ["setuptools>=61.0", "wheel"] +requires = ["setuptools>=61.0", "wheel", "torch"] build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.1" +version = "0.0.2" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.8" license = { file = "LICENSE" } classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: C++", + "Programming Language :: CUDA", +] +dependencies = [ + "torch", ] -dependencies = ["numpy"] - -[project.optional-dependencies] -srt = ["torch"] - -all = ["sgl-kernel[srt]"] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" -[tool.setuptools.packages.find] -exclude = [ - "dist*", - "tests*", -] +[tool.setuptools] +package-dir = {"sgl_kernel" = "src/sgl-kernel"} +packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"] [tool.wheel] exclude = [ - "dist*", - "tests*", + "dist*", + "tests*", ] diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py new file mode 100644 index 000000000..4702c7f20 --- /dev/null +++ b/sgl-kernel/setup.py @@ -0,0 +1,20 @@ +from setuptools import 
find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="sgl-kernel", + version="0.0.2", + packages=find_packages(where="src"), + package_dir={"": "src"}, + ext_modules=[ + CUDAExtension( + "sgl_kernel.ops.warp_reduce_cuda", + [ + "src/sgl-kernel/csrc/warp_reduce.cc", + "src/sgl-kernel/csrc/warp_reduce_kernel.cu", + ], + ) + ], + cmdclass={"build_ext": BuildExtension}, + install_requires=["torch"], +) diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index e69de29bb..edf3921db 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -0,0 +1,3 @@ +from .ops import warp_reduce + +__all__ = ["warp_reduce"] diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc new file mode 100644 index 000000000..66033c9d2 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc @@ -0,0 +1,21 @@ +#include <torch/extension.h> +#include <vector> + +torch::Tensor warp_reduce_cuda(torch::Tensor input); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor warp_reduce(torch::Tensor input) { + CHECK_INPUT(input); + return warp_reduce_cuda(input); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)"); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu new file mode 100644 index 000000000..c547682f6 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu @@ -0,0 +1,97 @@ +#include <torch/extension.h> +#include <cuda.h> +#include <cuda_runtime.h> + +#define FINAL_MASK 0xffffffff +#define BLOCK_SIZE 256 + +template <typename scalar_t> +__device__ __forceinline__ scalar_t add(scalar_t a, scalar_t b) { + return a + b; +} + +template <typename scalar_t> +__device__ 
__forceinline__ scalar_t warpReduceSum(scalar_t val) { +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(FINAL_MASK, val, offset); + } + return val; +} + +template <typename scalar_t> +__device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) { + __shared__ scalar_t shared[32]; + int lane = threadIdx.x % 32; + int wid = threadIdx.x / 32; + + val = warpReduceSum(val); // First reduce within warp + + if (lane == 0) + shared[wid] = val; // Write reduced value to shared memory + + __syncthreads(); // Wait for all partial reductions + + // Read from shared memory only if that warp existed + val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0; + + if (wid == 0) + val = warpReduceSum(val); // Final reduce within first warp + + return val; +} + +template <typename scalar_t> +__global__ void warp_reduce_cuda_kernel( + const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> + input, + torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, + int N) { + + scalar_t sum = 0; + + // Grid-stride loop + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + sum += input[i]; + } + + // Perform block-wide reduction + sum = blockReduceSum(sum); + + // Write result for this block to global memory + if (threadIdx.x == 0) { + output[blockIdx.x] = sum; + } +} + +torch::Tensor warp_reduce_cuda(torch::Tensor input) { + // Input validation + TORCH_CHECK(input.dim() == 1, "1D tensor expected"); + TORCH_CHECK(input.is_cuda(), "CUDA tensor expected"); + + const auto N = input.size(0); + + // Handle empty tensor + if (N == 0) { + return torch::zeros({1}, input.options()); + } + + // Calculate grid dimensions + const int threads = BLOCK_SIZE; + const int blocks = (N + threads - 1) / threads; + + // Allocate output tensor for partial sums + auto output = torch::empty({blocks}, input.options()); + + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "warp_reduce_cuda", ([&] { + warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>( + input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), + output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), + N); + })); + 
+ // Sum the partial results + return output.sum(); +} diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py new file mode 100644 index 000000000..21870032e --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -0,0 +1,5 @@ +from .warp_reduce_cuda import reduce as _reduce + + +def warp_reduce(input_tensor): + return _reduce(input_tensor)