minor: update sgl-kernel setup (#3107)

This commit is contained in:
Yineng Zhang
2025-01-24 20:10:35 +08:00
committed by GitHub
parent 4505a43614
commit 04f0b4cbef
2 changed files with 103 additions and 15 deletions

View File

@@ -38,10 +38,10 @@ def _get_version():
return line.split("=")[1].strip().strip('"')
cutlass = root / "3rdparty" / "cutlass"
cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer"
turbomind = root / "3rdparty" / "turbomind"
include_dirs = [
cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include",
@@ -49,6 +49,8 @@ include_dirs = [
flashinfer.resolve() / "include",
flashinfer.resolve() / "include" / "gemm",
flashinfer.resolve() / "csrc",
turbomind.resolve(),
turbomind.resolve() / "src",
]
nvcc_flags = [
"-DNDEBUG",
@@ -63,6 +65,11 @@ nvcc_flags = [
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
]
nvcc_flags_fp8 = [
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
sources = [
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
@@ -73,6 +80,7 @@ sources = [
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
"3rdparty/flashinfer/csrc/activation.cu",
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
"3rdparty/flashinfer/csrc/group_gemm.cu",
@@ -92,13 +100,7 @@ if torch.cuda.is_available():
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
if sm_version >= 90:
nvcc_flags.extend(
[
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
)
nvcc_flags.extend(nvcc_flags_fp8)
if sm_version >= 80:
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
else:
@@ -107,13 +109,7 @@ else:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
if enable_fp8:
nvcc_flags.extend(
[
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
)
nvcc_flags.extend(nvcc_flags_fp8)
if enable_bf16:
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")