format: add clang-format for sgl-kernel (#2483)
This commit is contained in:
8
sgl-kernel/.clang-format
Normal file
8
sgl-kernel/.clang-format
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
BasedOnStyle: Google
|
||||||
|
IndentWidth: 2
|
||||||
|
ColumnLimit: 120
|
||||||
|
AllowShortFunctionsOnASingleLine: Empty
|
||||||
|
DerivePointerAlignment: false
|
||||||
|
PointerAlignment: Left
|
||||||
|
NamespaceIndentation: None
|
||||||
|
SortIncludes: true
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: tree ln install build clean test
|
.PHONY: tree ln install build clean test format
|
||||||
|
|
||||||
tree:
|
tree:
|
||||||
@tree --prune -I "__pycache__|*.egg-info|*.so|build"
|
@tree --prune -I "__pycache__|*.egg-info|*.so|build"
|
||||||
@@ -17,3 +17,6 @@ clean:
|
|||||||
|
|
||||||
test:
|
test:
|
||||||
@pytest tests/
|
@pytest tests/
|
||||||
|
|
||||||
|
format:
|
||||||
|
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
|
||||||
|
|||||||
@@ -2,12 +2,10 @@
|
|||||||
|
|
||||||
torch::Tensor warp_reduce_cuda(torch::Tensor input);
|
torch::Tensor warp_reduce_cuda(torch::Tensor input);
|
||||||
|
|
||||||
#define CHECK_CUDA(x) \
|
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||||
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||||
#define CHECK_CONTIGUOUS(x) \
|
#define CHECK_INPUT(x) \
|
||||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
CHECK_CUDA(x); \
|
||||||
#define CHECK_INPUT(x) \
|
|
||||||
CHECK_CUDA(x); \
|
|
||||||
CHECK_CONTIGUOUS(x)
|
CHECK_CONTIGUOUS(x)
|
||||||
|
|
||||||
torch::Tensor warp_reduce(torch::Tensor input) {
|
torch::Tensor warp_reduce(torch::Tensor input) {
|
||||||
|
|||||||
@@ -25,34 +25,28 @@ __device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
|
|||||||
int lane = threadIdx.x % 32;
|
int lane = threadIdx.x % 32;
|
||||||
int wid = threadIdx.x / 32;
|
int wid = threadIdx.x / 32;
|
||||||
|
|
||||||
val = warpReduceSum(val); // First reduce within warp
|
val = warpReduceSum(val); // First reduce within warp
|
||||||
|
|
||||||
if (lane == 0)
|
if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
|
||||||
shared[wid] = val; // Write reduced value to shared memory
|
|
||||||
|
|
||||||
__syncthreads(); // Wait for all partial reductions
|
__syncthreads(); // Wait for all partial reductions
|
||||||
|
|
||||||
// Read from shared memory only if that warp existed
|
// Read from shared memory only if that warp existed
|
||||||
val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;
|
val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;
|
||||||
|
|
||||||
if (wid == 0)
|
if (wid == 0) val = warpReduceSum(val); // Final reduce within first warp
|
||||||
val = warpReduceSum(val); // Final reduce within first warp
|
|
||||||
|
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void warp_reduce_cuda_kernel(
|
__global__ void warp_reduce_cuda_kernel(
|
||||||
const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
|
const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
|
||||||
input,
|
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, int N) {
|
||||||
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
|
|
||||||
int N) {
|
|
||||||
|
|
||||||
scalar_t sum = 0;
|
scalar_t sum = 0;
|
||||||
|
|
||||||
// Grid-stride loop
|
// Grid-stride loop
|
||||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
|
||||||
i += blockDim.x * gridDim.x) {
|
|
||||||
sum += input[i];
|
sum += input[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,13 +78,11 @@ torch::Tensor warp_reduce_cuda(torch::Tensor input) {
|
|||||||
// Allocate output tensor for partial sums
|
// Allocate output tensor for partial sums
|
||||||
auto output = torch::empty({blocks}, input.options());
|
auto output = torch::empty({blocks}, input.options());
|
||||||
|
|
||||||
AT_DISPATCH_FLOATING_TYPES(
|
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
|
||||||
input.scalar_type(), "warp_reduce_cuda", ([&] {
|
warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
|
||||||
warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
|
input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
|
||||||
input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
|
output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), N);
|
||||||
output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
|
}));
|
||||||
N);
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Sum the partial results
|
// Sum the partial results
|
||||||
return output.sum();
|
return output.sum();
|
||||||
|
|||||||
Reference in New Issue
Block a user