minor: update sgl-kernel setup (#3107)
This commit is contained in:
@@ -38,10 +38,10 @@ def _get_version():
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
cutlass = root / "3rdparty" / "cutlass"
|
||||
cutlass_default = root / "3rdparty" / "cutlass"
|
||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||
flashinfer = root / "3rdparty" / "flashinfer"
|
||||
turbomind = root / "3rdparty" / "turbomind"
|
||||
include_dirs = [
|
||||
cutlass.resolve() / "include",
|
||||
cutlass.resolve() / "tools" / "util" / "include",
|
||||
@@ -49,6 +49,8 @@ include_dirs = [
|
||||
flashinfer.resolve() / "include",
|
||||
flashinfer.resolve() / "include" / "gemm",
|
||||
flashinfer.resolve() / "csrc",
|
||||
turbomind.resolve(),
|
||||
turbomind.resolve() / "src",
|
||||
]
|
||||
nvcc_flags = [
|
||||
"-DNDEBUG",
|
||||
@@ -63,6 +65,11 @@ nvcc_flags = [
|
||||
"-use_fast_math",
|
||||
"-DFLASHINFER_ENABLE_F16",
|
||||
]
|
||||
nvcc_flags_fp8 = [
|
||||
"-DFLASHINFER_ENABLE_FP8",
|
||||
"-DFLASHINFER_ENABLE_FP8_E4M3",
|
||||
"-DFLASHINFER_ENABLE_FP8_E5M2",
|
||||
]
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||
@@ -73,6 +80,7 @@ sources = [
|
||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||
"3rdparty/flashinfer/csrc/group_gemm.cu",
|
||||
@@ -92,13 +100,7 @@ if torch.cuda.is_available():
|
||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||
sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
|
||||
if sm_version >= 90:
|
||||
nvcc_flags.extend(
|
||||
[
|
||||
"-DFLASHINFER_ENABLE_FP8",
|
||||
"-DFLASHINFER_ENABLE_FP8_E4M3",
|
||||
"-DFLASHINFER_ENABLE_FP8_E5M2",
|
||||
]
|
||||
)
|
||||
nvcc_flags.extend(nvcc_flags_fp8)
|
||||
if sm_version >= 80:
|
||||
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
|
||||
else:
|
||||
@@ -107,13 +109,7 @@ else:
|
||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||
sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
|
||||
if enable_fp8:
|
||||
nvcc_flags.extend(
|
||||
[
|
||||
"-DFLASHINFER_ENABLE_FP8",
|
||||
"-DFLASHINFER_ENABLE_FP8_E4M3",
|
||||
"-DFLASHINFER_ENABLE_FP8_E5M2",
|
||||
]
|
||||
)
|
||||
nvcc_flags.extend(nvcc_flags_fp8)
|
||||
if enable_bf16:
|
||||
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user