feat: add flashinfer as 3rdparty and use rmsnorm as example (#3033)

This commit is contained in:
Yineng Zhang
2025-01-21 20:44:49 +08:00
committed by GitHub
parent a4331cd260
commit 5a0d680a14
11 changed files with 335 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
from pathlib import Path
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
@@ -24,10 +25,13 @@ def update_wheel_platform_tag():
cutlass = root / "3rdparty" / "cutlass"
flashinfer = root / "3rdparty" / "flashinfer"
include_dirs = [
cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include",
root / "src" / "sgl-kernel" / "csrc",
flashinfer.resolve() / "include",
flashinfer.resolve() / "csrc",
]
nvcc_flags = [
"-DNDEBUG",
@@ -39,9 +43,21 @@ nvcc_flags = [
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-gencode=arch=compute_90a,code=sm_90a",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
]
for flag in [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]:
try:
torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
except ValueError:
pass
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
@@ -56,6 +72,7 @@ ext_modules = [
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"src/sgl-kernel/csrc/norm.cu",
],
include_dirs=include_dirs,
extra_compile_args={