format: add clang-format for sgl-kernel (#2483)
This commit is contained in:
8
sgl-kernel/.clang-format
Normal file
8
sgl-kernel/.clang-format
Normal file
@@ -0,0 +1,8 @@
|
||||
BasedOnStyle: Google
|
||||
IndentWidth: 2
|
||||
ColumnLimit: 120
|
||||
AllowShortFunctionsOnASingleLine: Empty
|
||||
DerivePointerAlignment: false
|
||||
PointerAlignment: Left
|
||||
NamespaceIndentation: None
|
||||
SortIncludes: true
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: tree ln install build clean test
|
||||
.PHONY: tree ln install build clean test format
|
||||
|
||||
tree:
|
||||
@tree --prune -I "__pycache__|*.egg-info|*.so|build"
|
||||
@@ -17,3 +17,6 @@ clean:
|
||||
|
||||
test:
|
||||
@pytest tests/
|
||||
|
||||
format:
|
||||
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
torch::Tensor warp_reduce_cuda(torch::Tensor input);
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
torch::Tensor warp_reduce(torch::Tensor input) {
|
||||
|
||||
@@ -25,34 +25,28 @@ __device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
|
||||
int lane = threadIdx.x % 32;
|
||||
int wid = threadIdx.x / 32;
|
||||
|
||||
val = warpReduceSum(val); // First reduce within warp
|
||||
val = warpReduceSum(val); // First reduce within warp
|
||||
|
||||
if (lane == 0)
|
||||
shared[wid] = val; // Write reduced value to shared memory
|
||||
if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
|
||||
|
||||
__syncthreads(); // Wait for all partial reductions
|
||||
__syncthreads(); // Wait for all partial reductions
|
||||
|
||||
// Read from shared memory only if that warp existed
|
||||
val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;
|
||||
|
||||
if (wid == 0)
|
||||
val = warpReduceSum(val); // Final reduce within first warp
|
||||
if (wid == 0) val = warpReduceSum(val); // Final reduce within first warp
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void warp_reduce_cuda_kernel(
|
||||
const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
|
||||
input,
|
||||
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
|
||||
int N) {
|
||||
|
||||
const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
|
||||
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, int N) {
|
||||
scalar_t sum = 0;
|
||||
|
||||
// Grid-stride loop
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
||||
i += blockDim.x * gridDim.x) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
|
||||
sum += input[i];
|
||||
}
|
||||
|
||||
@@ -84,13 +78,11 @@ torch::Tensor warp_reduce_cuda(torch::Tensor input) {
|
||||
// Allocate output tensor for partial sums
|
||||
auto output = torch::empty({blocks}, input.options());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "warp_reduce_cuda", ([&] {
|
||||
warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
|
||||
input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
|
||||
output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
|
||||
N);
|
||||
}));
|
||||
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
|
||||
warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
|
||||
input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
|
||||
output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), N);
|
||||
}));
|
||||
|
||||
// Sum the partial results
|
||||
return output.sum();
|
||||
|
||||
Reference in New Issue
Block a user