Refactor sgl-kernel build (#2642)
This commit is contained in:
@@ -58,78 +58,45 @@ def update_wheel_platform_tag():
|
||||
old_wheel.rename(new_wheel)
|
||||
|
||||
|
||||
nvcc_flags = [
|
||||
"-O3",
|
||||
"-Xcompiler",
|
||||
"-fPIC",
|
||||
"-gencode=arch=compute_75,code=sm_75",
|
||||
"-gencode=arch=compute_80,code=sm_80",
|
||||
"-gencode=arch=compute_89,code=sm_89",
|
||||
"-gencode=arch=compute_90,code=sm_90",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
]
|
||||
cxx_flags = ["-O3"]
|
||||
libraries = ["c10", "torch", "torch_python"]
|
||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
|
||||
ext_modules = [
|
||||
CUDAExtension(
|
||||
name="sgl_kernel.ops._kernels",
|
||||
sources=[
|
||||
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"nvcc": nvcc_flags,
|
||||
"cxx": cxx_flags,
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
),
|
||||
]
|
||||
|
||||
setup(
|
||||
name="sgl-kernel",
|
||||
version=get_version(),
|
||||
packages=["sgl_kernel"],
|
||||
package_dir={"": "src"},
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
"sgl_kernel.ops.warp_reduce_cuda",
|
||||
[
|
||||
"src/sgl-kernel/csrc/warp_reduce.cc",
|
||||
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"nvcc": [
|
||||
"-O3",
|
||||
"-Xcompiler",
|
||||
"-fPIC",
|
||||
"-gencode=arch=compute_75,code=sm_75",
|
||||
"-gencode=arch=compute_80,code=sm_80",
|
||||
"-gencode=arch=compute_89,code=sm_89",
|
||||
"-gencode=arch=compute_90,code=sm_90",
|
||||
],
|
||||
"cxx": ["-O3"],
|
||||
},
|
||||
libraries=["c10", "torch", "torch_python"],
|
||||
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
|
||||
),
|
||||
CUDAExtension(
|
||||
"sgl_kernel.ops.custom_reduce_cuda",
|
||||
[
|
||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce.cc",
|
||||
],
|
||||
extra_compile_args={
|
||||
"nvcc": [
|
||||
"-O3",
|
||||
"-Xcompiler",
|
||||
"-fPIC",
|
||||
"-gencode=arch=compute_75,code=sm_75",
|
||||
"-gencode=arch=compute_80,code=sm_80",
|
||||
"-gencode=arch=compute_89,code=sm_89",
|
||||
"-gencode=arch=compute_90,code=sm_90",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
],
|
||||
"cxx": ["-O3"],
|
||||
},
|
||||
libraries=["c10", "torch", "torch_python"],
|
||||
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
|
||||
),
|
||||
CUDAExtension(
|
||||
"sgl_kernel.ops.moe_align_block_size",
|
||||
[
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"nvcc": [
|
||||
"-O3",
|
||||
"-Xcompiler",
|
||||
"-fPIC",
|
||||
"-gencode=arch=compute_75,code=sm_75",
|
||||
"-gencode=arch=compute_80,code=sm_80",
|
||||
"-gencode=arch=compute_89,code=sm_89",
|
||||
"-gencode=arch=compute_90,code=sm_90",
|
||||
],
|
||||
"cxx": ["-O3"],
|
||||
},
|
||||
libraries=["c10", "torch", "torch_python"],
|
||||
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
|
||||
),
|
||||
],
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
install_requires=["torch"],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user