Fix cu118 group gemm compile issue (#3097)
This commit is contained in:
@@ -62,6 +62,23 @@ nvcc_flags = [
|
|||||||
"-DFLASHINFER_ENABLE_F16",
|
"-DFLASHINFER_ENABLE_F16",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
sources = [
|
||||||
|
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||||
|
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||||
|
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||||
|
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||||
|
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||||
|
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||||
|
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||||
|
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||||
|
"3rdparty/flashinfer/csrc/activation.cu",
|
||||||
|
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||||
|
"3rdparty/flashinfer/csrc/group_gemm.cu",
|
||||||
|
"3rdparty/flashinfer/csrc/norm.cu",
|
||||||
|
"3rdparty/flashinfer/csrc/sampling.cu",
|
||||||
|
"3rdparty/flashinfer/csrc/renorm.cu",
|
||||||
|
]
|
||||||
|
|
||||||
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
|
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
|
||||||
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
|
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
|
||||||
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
|
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
|
||||||
@@ -71,6 +88,7 @@ sm_version = _get_device_sm()
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
if cuda_version >= (12, 0) and sm_version >= 90:
|
if cuda_version >= (12, 0) and sm_version >= 90:
|
||||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||||
|
sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
|
||||||
if sm_version >= 90:
|
if sm_version >= 90:
|
||||||
nvcc_flags.extend(
|
nvcc_flags.extend(
|
||||||
[
|
[
|
||||||
@@ -85,6 +103,7 @@ else:
|
|||||||
# compilation environment without GPU
|
# compilation environment without GPU
|
||||||
if enable_sm90a:
|
if enable_sm90a:
|
||||||
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
|
||||||
|
sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
|
||||||
if enable_fp8:
|
if enable_fp8:
|
||||||
nvcc_flags.extend(
|
nvcc_flags.extend(
|
||||||
[
|
[
|
||||||
@@ -110,26 +129,11 @@ for flag in [
|
|||||||
cxx_flags = ["-O3"]
|
cxx_flags = ["-O3"]
|
||||||
libraries = ["c10", "torch", "torch_python", "cuda"]
|
libraries = ["c10", "torch", "torch_python", "cuda"]
|
||||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||||
|
|
||||||
ext_modules = [
|
ext_modules = [
|
||||||
CUDAExtension(
|
CUDAExtension(
|
||||||
name="sgl_kernel.ops._kernels",
|
name="sgl_kernel.ops._kernels",
|
||||||
sources=[
|
sources=sources,
|
||||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
|
||||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
|
||||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
|
||||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
|
||||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
|
||||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
|
||||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
|
||||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/activation.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/group_gemm.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/norm.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/sampling.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/renorm.cu",
|
|
||||||
],
|
|
||||||
include_dirs=include_dirs,
|
include_dirs=include_dirs,
|
||||||
extra_compile_args={
|
extra_compile_args={
|
||||||
"nvcc": nvcc_flags,
|
"nvcc": nvcc_flags,
|
||||||
|
|||||||
Reference in New Issue
Block a user