feat: use warp reduce as a simple example (#2304)
This commit is contained in:
33
.gitignore
vendored
33
.gitignore
vendored
@@ -185,3 +185,36 @@ work_dirs/
|
|||||||
*.csv
|
*.csv
|
||||||
|
|
||||||
!logo.png
|
!logo.png
|
||||||
|
|
||||||
|
# Prerequisites
|
||||||
|
*.d
|
||||||
|
|
||||||
|
# Compiled Object files
|
||||||
|
*.slo
|
||||||
|
*.lo
|
||||||
|
*.o
|
||||||
|
*.obj
|
||||||
|
|
||||||
|
# Precompiled Headers
|
||||||
|
*.gch
|
||||||
|
*.pch
|
||||||
|
|
||||||
|
# Compiled Dynamic libraries
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
*.dll
|
||||||
|
|
||||||
|
# Fortran module files
|
||||||
|
*.mod
|
||||||
|
*.smod
|
||||||
|
|
||||||
|
# Compiled Static libraries
|
||||||
|
*.lai
|
||||||
|
*.la
|
||||||
|
*.a
|
||||||
|
*.lib
|
||||||
|
|
||||||
|
# Executables
|
||||||
|
*.exe
|
||||||
|
*.out
|
||||||
|
*.app
|
||||||
|
|||||||
@@ -1,37 +1,34 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=61.0", "wheel"]
|
requires = ["setuptools>=61.0", "wheel", "torch"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sgl-kernel"
|
name = "sgl-kernel"
|
||||||
version = "0.0.1"
|
version = "0.0.2"
|
||||||
description = "Kernel Library for SGLang"
|
description = "Kernel Library for SGLang"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
|
"Programming Language :: C++",
|
||||||
|
"Programming Language :: CUDA",
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"torch",
|
||||||
]
|
]
|
||||||
dependencies = ["numpy"]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
srt = ["torch"]
|
|
||||||
|
|
||||||
all = ["sgl-kernel[srt]"]
|
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
"Homepage" = "https://github.com/sgl-project/sglang"
|
"Homepage" = "https://github.com/sgl-project/sglang"
|
||||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools]
|
||||||
exclude = [
|
package-dir = {"sgl_kernel" = "src/sgl-kernel"}
|
||||||
"dist*",
|
packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"]
|
||||||
"tests*",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.wheel]
|
[tool.wheel]
|
||||||
exclude = [
|
exclude = [
|
||||||
"dist*",
|
"dist*",
|
||||||
"tests*",
|
"tests*",
|
||||||
]
|
]
|
||||||
|
|||||||
20
sgl-kernel/setup.py
Normal file
20
sgl-kernel/setup.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""Build script for the sgl-kernel package.

Builds the example warp-reduce CUDA extension using torch's C++/CUDA
build helpers.
"""
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Sources for the example warp-reduce extension module.
_ext_sources = [
    "src/sgl-kernel/csrc/warp_reduce.cc",
    "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
]

setup(
    name="sgl-kernel",
    version="0.0.2",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    ext_modules=[
        # Compiled module is importable as sgl_kernel.ops.warp_reduce_cuda.
        CUDAExtension("sgl_kernel.ops.warp_reduce_cuda", _ext_sources),
    ],
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .ops import warp_reduce
|
||||||
|
|
||||||
|
__all__ = ["warp_reduce"]
|
||||||
|
|||||||
21
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
Normal file
21
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
// Host-side entry point and Python bindings for the warp-reduce example.
#include <torch/extension.h>

#include <vector>

// Implemented in warp_reduce_kernel.cu.
torch::Tensor warp_reduce_cuda(torch::Tensor input);

// Input-validation helpers. Each macro is wrapped in do { } while (0) so
// it expands to exactly one statement and stays safe inside unbraced
// if/else bodies (the original CHECK_INPUT expanded to two statements).
#define CHECK_CUDA(x) \
  do { TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor"); } while (0)
#define CHECK_CONTIGUOUS(x) \
  do { TORCH_CHECK(x.is_contiguous(), #x " must be contiguous"); } while (0)
#define CHECK_INPUT(x) \
  do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

// Validates that `input` is a contiguous CUDA tensor, then forwards to
// the CUDA implementation. Throws (via TORCH_CHECK) on invalid input.
torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_INPUT(input);
  return warp_reduce_cuda(input);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
}
|
||||||
97
sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu
Normal file
97
sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
// CUDA implementation of a simple sum reduction, used as the sgl-kernel
// example extension.
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

// Full-warp (32-lane) participation mask for the *_sync shuffle intrinsics.
#define FINAL_MASK 0xffffffff
// Threads per block used by the host launcher in this file.
#define BLOCK_SIZE 256

// Generic addition helper.
// NOTE(review): currently unused — the reductions below use += directly;
// kept as-is for parity with the original source.
template <typename scalar_t>
__device__ __forceinline__ scalar_t add(scalar_t a, scalar_t b) {
  return a + b;
}
|
||||||
|
|
||||||
|
// Sums `val` across the 32 lanes of a warp using shuffle-down exchanges.
// After the loop, lane 0 holds the warp total; other lanes hold partials.
// All 32 lanes must participate (FINAL_MASK).
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
#pragma unroll
  for (int delta = 16; delta > 0; delta >>= 1) {
    val += __shfl_down_sync(FINAL_MASK, val, delta);
  }
  return val;
}
|
||||||
|
|
||||||
|
// Sums `val` across all threads of the block; thread 0 ends up holding
// the block total. Assumes blockDim.x <= 1024 (at most 32 warps, matching
// the shared[] capacity below).
template <typename scalar_t>
__device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
  __shared__ scalar_t shared[32];  // one partial sum per warp
  int lane = threadIdx.x % 32;
  int wid = threadIdx.x / 32;

  val = warpReduceSum(val);  // First reduce within warp

  if (lane == 0)
    shared[wid] = val;  // Write reduced value to shared memory

  __syncthreads();  // Wait for all partial reductions

  // Number of warps, rounded up so a partial trailing warp is counted.
  // (The original used blockDim.x / 32, which silently dropped the last
  // warp's partial sum whenever blockDim.x was not a multiple of 32.)
  int num_warps = (blockDim.x + 31) / 32;

  // Read from shared memory only if that warp existed
  val = (threadIdx.x < num_warps) ? shared[lane] : 0;

  if (wid == 0)
    val = warpReduceSum(val);  // Final reduce within first warp

  return val;
}
|
||||||
|
|
||||||
|
// Grid-stride partial-sum kernel: each block reduces its strided slice of
// `input` (length N) and writes one partial sum to output[blockIdx.x].
// Works for any grid size, including a single block.
template <typename scalar_t>
__global__ void warp_reduce_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
        input,
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
    int N) {
  scalar_t thread_sum = 0;

  // Grid-stride loop: each thread accumulates a strided subset of input.
  const int stride = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += stride) {
    thread_sum += input[idx];
  }

  // Combine the per-thread sums across the whole block.
  thread_sum = blockReduceSum(thread_sum);

  // Thread 0 publishes this block's partial sum to global memory.
  if (threadIdx.x == 0) {
    output[blockIdx.x] = thread_sum;
  }
}
|
||||||
|
|
||||||
|
// Sums a 1-D CUDA tensor. Launches a grid-stride partial-sum kernel, then
// finishes the reduction with tensor.sum() on the per-block partials.
// Returns a 0-dim tensor holding the total. Throws on invalid input.
torch::Tensor warp_reduce_cuda(torch::Tensor input) {
  // Input validation
  TORCH_CHECK(input.dim() == 1, "1D tensor expected");
  TORCH_CHECK(input.is_cuda(), "CUDA tensor expected");

  const auto N = input.size(0);
  // The kernel indexes with 32-bit ints; reject sizes that would be
  // silently truncated by the int64 -> int conversion below.
  TORCH_CHECK(N <= 2147483647, "tensor too large for 32-bit indexing");

  // Handle empty tensor: return a 0-dim zero, matching the shape of
  // output.sum() on the normal path (the original returned shape {1}).
  if (N == 0) {
    return torch::zeros({}, input.options());
  }

  // Calculate grid dimensions (ceil-div so the tail is covered).
  const int threads = BLOCK_SIZE;
  const int blocks = (static_cast<int>(N) + threads - 1) / threads;

  // Allocate output tensor for per-block partial sums.
  auto output = torch::empty({blocks}, input.options());

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "warp_reduce_cuda", ([&] {
        warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
            input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
            output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
            static_cast<int>(N));
      }));

  // Kernel launches do not return errors; surface launch-configuration
  // failures explicitly instead of letting them propagate silently.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "warp_reduce kernel launch failed: ", cudaGetErrorString(err));

  // Sum the partial results (synchronizes as needed via torch).
  return output.sum();
}
|
||||||
5
sgl-kernel/src/sgl-kernel/ops/__init__.py
Normal file
5
sgl-kernel/src/sgl-kernel/ops/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .warp_reduce_cuda import reduce as _reduce
|
||||||
|
|
||||||
|
|
||||||
|
def warp_reduce(input_tensor):
    """Reduce ``input_tensor`` with the compiled CUDA extension.

    Thin wrapper around the extension's ``reduce`` entry point; the
    extension validates the input (expects a contiguous CUDA tensor).
    """
    result = _reduce(input_tensor)
    return result
|
||||||
Reference in New Issue
Block a user