misc: update build setup (#2306)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -218,3 +218,5 @@ work_dirs/
|
|||||||
*.exe
|
*.exe
|
||||||
*.out
|
*.out
|
||||||
*.app
|
*.app
|
||||||
|
|
||||||
|
compile_commands.json
|
||||||
|
|||||||
19
sgl-kernel/Makefile
Normal file
19
sgl-kernel/Makefile
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
.PHONY: tree ln install build clean test
|
||||||
|
|
||||||
|
tree:
|
||||||
|
@tree --prune -I "__pycache__|*.egg-info|*.so|build"
|
||||||
|
|
||||||
|
ln:
|
||||||
|
@rm -rf build && cmake . -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DCMAKE_CUDA_COMPILER=nvcc -B build && rm -rf compile_commands.json && ln -s build/compile_commands.json compile_commands.json
|
||||||
|
|
||||||
|
install:
|
||||||
|
@pip install -e .
|
||||||
|
|
||||||
|
build:
|
||||||
|
@python3 setup.py bdist_wheel
|
||||||
|
|
||||||
|
clean:
|
||||||
|
@rm -rf build dist *.egg-info
|
||||||
|
|
||||||
|
test:
|
||||||
|
@pytest tests/
|
||||||
13
sgl-kernel/build.sh
Executable file
13
sgl-kernel/build.sh
Executable file
@@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
docker run --rm -it \
|
||||||
|
-v "$(pwd)":/sgl-kernel \
|
||||||
|
pytorch/manylinux-builder:cuda12.1 \
|
||||||
|
bash -c "
|
||||||
|
pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121 && \
|
||||||
|
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
|
||||||
|
cd /sgl-kernel && \
|
||||||
|
python setup.py bdist_wheel
|
||||||
|
"
|
||||||
@@ -13,6 +13,18 @@ setup(
|
|||||||
"src/sgl-kernel/csrc/warp_reduce.cc",
|
"src/sgl-kernel/csrc/warp_reduce.cc",
|
||||||
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
|
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
|
||||||
],
|
],
|
||||||
|
extra_compile_args={
|
||||||
|
"nvcc": [
|
||||||
|
"-O3",
|
||||||
|
"-Xcompiler",
|
||||||
|
"-fPIC",
|
||||||
|
"-gencode=arch=compute_75,code=sm_75",
|
||||||
|
"-gencode=arch=compute_80,code=sm_80",
|
||||||
|
"-gencode=arch=compute_89,code=sm_89",
|
||||||
|
"-gencode=arch=compute_90,code=sm_90",
|
||||||
|
],
|
||||||
|
"cxx": ["-O3"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
cmdclass={"build_ext": BuildExtension},
|
cmdclass={"build_ext": BuildExtension},
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
torch::Tensor warp_reduce_cuda(torch::Tensor input);
|
torch::Tensor warp_reduce_cuda(torch::Tensor input);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user