minor: cleanup sgl-kernel (#2679)
This commit is contained in:
@@ -28,7 +28,6 @@ find_package(Torch REQUIRED)
|
|||||||
|
|
||||||
# Warp Reduce library
|
# Warp Reduce library
|
||||||
add_library(_kernels SHARED
|
add_library(_kernels SHARED
|
||||||
src/sgl-kernel/csrc/warp_reduce_kernel.cu
|
|
||||||
src/sgl-kernel/csrc/trt_reduce_internal.cu
|
src/sgl-kernel/csrc/trt_reduce_internal.cu
|
||||||
src/sgl-kernel/csrc/trt_reduce_kernel.cu
|
src/sgl-kernel/csrc/trt_reduce_kernel.cu
|
||||||
src/sgl-kernel/csrc/moe_align_kernel.cu
|
src/sgl-kernel/csrc/moe_align_kernel.cu
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
import os
|
|
||||||
import shutil
|
|
||||||
import zipfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
@@ -16,39 +13,6 @@ def get_version():
|
|||||||
return line.split("=")[1].strip().strip('"')
|
return line.split("=")[1].strip().strip('"')
|
||||||
|
|
||||||
|
|
||||||
def rename_wheel():
|
|
||||||
if not os.environ.get("CUDA_VERSION"):
|
|
||||||
return
|
|
||||||
cuda_version = os.environ["CUDA_VERSION"].replace(".", "")
|
|
||||||
base_version = get_version()
|
|
||||||
|
|
||||||
wheel_dir = Path("dist")
|
|
||||||
old_wheel = next(wheel_dir.glob("*.whl"))
|
|
||||||
tmp_dir = wheel_dir / "tmp"
|
|
||||||
tmp_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
with zipfile.ZipFile(old_wheel, "r") as zip_ref:
|
|
||||||
zip_ref.extractall(tmp_dir)
|
|
||||||
|
|
||||||
old_info = tmp_dir / f"sgl_kernel-{base_version}.dist-info"
|
|
||||||
new_info = tmp_dir / f"sgl_kernel-{base_version}.post0+cu{cuda_version}.dist-info"
|
|
||||||
old_info.rename(new_info)
|
|
||||||
|
|
||||||
platform = "manylinux2014_x86_64"
|
|
||||||
new_wheel = wheel_dir / old_wheel.name.replace("linux_x86_64", platform)
|
|
||||||
new_wheel = wheel_dir / new_wheel.name.replace(
|
|
||||||
base_version, f"{base_version}.post0+cu{cuda_version}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with zipfile.ZipFile(new_wheel, "w", zipfile.ZIP_DEFLATED) as new_zip:
|
|
||||||
for file_path in tmp_dir.rglob("*"):
|
|
||||||
if file_path.is_file():
|
|
||||||
new_zip.write(file_path, file_path.relative_to(tmp_dir))
|
|
||||||
|
|
||||||
old_wheel.unlink()
|
|
||||||
shutil.rmtree(tmp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def update_wheel_platform_tag():
|
def update_wheel_platform_tag():
|
||||||
wheel_dir = Path("dist")
|
wheel_dir = Path("dist")
|
||||||
old_wheel = next(wheel_dir.glob("*.whl"))
|
old_wheel = next(wheel_dir.glob("*.whl"))
|
||||||
@@ -81,7 +45,6 @@ ext_modules = [
|
|||||||
CUDAExtension(
|
CUDAExtension(
|
||||||
name="sgl_kernel.ops._kernels",
|
name="sgl_kernel.ops._kernels",
|
||||||
sources=[
|
sources=[
|
||||||
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
|
|
||||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||||
|
|||||||
@@ -3,12 +3,10 @@ from sgl_kernel.ops import (
|
|||||||
custom_reduce,
|
custom_reduce,
|
||||||
init_custom_reduce,
|
init_custom_reduce,
|
||||||
moe_align_block_size,
|
moe_align_block_size,
|
||||||
warp_reduce,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"moe_align_block_size",
|
"moe_align_block_size",
|
||||||
"warp_reduce",
|
|
||||||
"init_custom_reduce",
|
"init_custom_reduce",
|
||||||
"custom_dispose",
|
"custom_dispose",
|
||||||
"custom_reduce",
|
"custom_reduce",
|
||||||
|
|||||||
@@ -1,13 +1,5 @@
|
|||||||
#include "utils.hpp"
|
#include "utils.hpp"
|
||||||
|
|
||||||
// warp_reduce
|
|
||||||
torch::Tensor warp_reduce_cuda(torch::Tensor input);
|
|
||||||
|
|
||||||
torch::Tensor warp_reduce(torch::Tensor input) {
|
|
||||||
CHECK_CUDA_INPUT(input);
|
|
||||||
return warp_reduce_cuda(input);
|
|
||||||
}
|
|
||||||
|
|
||||||
// trt_reduce
|
// trt_reduce
|
||||||
using fptr_t = int64_t;
|
using fptr_t = int64_t;
|
||||||
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
|
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
|
||||||
@@ -21,8 +13,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
|
|||||||
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
|
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
// warp_reduce
|
|
||||||
m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
|
|
||||||
// trt_reduce
|
// trt_reduce
|
||||||
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
|
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
|
||||||
m.def("dispose", &dispose, "dispose custom allreduce meta");
|
m.def("dispose", &dispose, "dispose custom allreduce meta");
|
||||||
|
|||||||
@@ -1,90 +0,0 @@
|
|||||||
#include <cuda.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
#define FINAL_MASK 0xffffffff
|
|
||||||
#define BLOCK_SIZE 256
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__device__ __forceinline__ scalar_t add(scalar_t a, scalar_t b) {
|
|
||||||
return a + b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int offset = 16; offset > 0; offset /= 2) {
|
|
||||||
val += __shfl_down_sync(FINAL_MASK, val, offset);
|
|
||||||
}
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
|
|
||||||
__shared__ scalar_t shared[32];
|
|
||||||
int lane = threadIdx.x % 32;
|
|
||||||
int wid = threadIdx.x / 32;
|
|
||||||
|
|
||||||
val = warpReduceSum(val); // First reduce within warp
|
|
||||||
|
|
||||||
if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
|
|
||||||
|
|
||||||
__syncthreads(); // Wait for all partial reductions
|
|
||||||
|
|
||||||
// Read from shared memory only if that warp existed
|
|
||||||
val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;
|
|
||||||
|
|
||||||
if (wid == 0) val = warpReduceSum(val); // Final reduce within first warp
|
|
||||||
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__global__ void warp_reduce_cuda_kernel(
|
|
||||||
const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
|
|
||||||
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, int N) {
|
|
||||||
scalar_t sum = 0;
|
|
||||||
|
|
||||||
// Grid-stride loop
|
|
||||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
|
|
||||||
sum += input[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform block-wide reduction
|
|
||||||
sum = blockReduceSum(sum);
|
|
||||||
|
|
||||||
// Write result for this block to global memory
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
output[blockIdx.x] = sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor warp_reduce_cuda(torch::Tensor input) {
|
|
||||||
// Input validation
|
|
||||||
TORCH_CHECK(input.dim() == 1, "1D tensor expected");
|
|
||||||
TORCH_CHECK(input.is_cuda(), "CUDA tensor expected");
|
|
||||||
|
|
||||||
const auto N = input.size(0);
|
|
||||||
|
|
||||||
// Handle empty tensor
|
|
||||||
if (N == 0) {
|
|
||||||
return torch::zeros({1}, input.options());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate grid dimensions
|
|
||||||
const int threads = BLOCK_SIZE;
|
|
||||||
const int blocks = (N + threads - 1) / threads;
|
|
||||||
|
|
||||||
// Allocate output tensor for partial sums
|
|
||||||
auto output = torch::empty({blocks}, input.options());
|
|
||||||
|
|
||||||
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
|
|
||||||
warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
|
|
||||||
input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
|
|
||||||
output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), N);
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Sum the partial results
|
|
||||||
return output.sum();
|
|
||||||
}
|
|
||||||
@@ -2,11 +2,6 @@ from sgl_kernel.ops._kernels import all_reduce as _all_reduce
|
|||||||
from sgl_kernel.ops._kernels import dispose as _dispose
|
from sgl_kernel.ops._kernels import dispose as _dispose
|
||||||
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
||||||
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
||||||
from sgl_kernel.ops._kernels import reduce as _reduce
|
|
||||||
|
|
||||||
|
|
||||||
def warp_reduce(input_tensor):
|
|
||||||
return _reduce(input_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
|
def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
|
||||||
|
|||||||
Reference in New Issue
Block a user