feat: add flashinfer as 3rdparty and use rmsnorm as example (#3033)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
@@ -24,10 +25,13 @@ def update_wheel_platform_tag():
|
||||
|
||||
|
||||
cutlass = root / "3rdparty" / "cutlass"
flashinfer = root / "3rdparty" / "flashinfer"
include_dirs = [
    cutlass.resolve() / "include",
    cutlass.resolve() / "tools" / "util" / "include",
    root / "src" / "sgl-kernel" / "csrc",
    flashinfer.resolve() / "include",
    flashinfer.resolve() / "csrc",
]
nvcc_flags = [
    "-DNDEBUG",
|
||||
@@ -39,9 +43,21 @@ nvcc_flags = [
|
||||
"-gencode=arch=compute_89,code=sm_89",
|
||||
"-gencode=arch=compute_90,code=sm_90",
|
||||
"-gencode=arch=compute_90a,code=sm_90a",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
"-std=c++17",
|
||||
"-use_fast_math",
|
||||
"-DFLASHINFER_ENABLE_F16",
|
||||
"-DFLASHINFER_ENABLE_BF16",
|
||||
]
|
||||
for flag in [
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
]:
    try:
        torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
    except ValueError:
        pass
|
||||
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||
@@ -56,6 +72,7 @@ ext_modules = [
|
||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
"src/sgl-kernel/csrc/norm.cu",
|
||||
],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
|
||||
Reference in New Issue
Block a user