feat: integrate bmm_fp8 kernel into sgl-kernel (#3056)
This commit is contained in:
@@ -62,12 +62,22 @@ nvcc_flags = [
|
||||
"-std=c++17",
|
||||
"-use_fast_math",
|
||||
"-DFLASHINFER_ENABLE_F16",
|
||||
"-DFLASHINFER_ENABLE_BF16",
|
||||
]
|
||||
|
||||
# Version-dependent additions to nvcc_flags. The three checks are independent
# and run in this order, so the resulting flag order is: sm_90a gencode target,
# FP8 macros, BF16 macro.

# Emit the sm_90a code-generation target only when cuda_version is at least
# (12, 0) and sm_version is at least 90.
if cuda_version >= (12, 0) and sm_version >= 90:
    nvcc_flags += ["-gencode=arch=compute_90a,code=sm_90a"]

# Define the FlashInfer FP8 macros (generic, E4M3 and E5M2) when sm_version
# is at least 90.
if sm_version >= 90:
    nvcc_flags += [
        "-DFLASHINFER_ENABLE_FP8",
        "-DFLASHINFER_ENABLE_FP8_E4M3",
        "-DFLASHINFER_ENABLE_FP8_E5M2",
    ]

# Define the FlashInfer BF16 macro when sm_version is at least 80.
# NOTE(review): the nvcc_flags literal shown just above this block also lists
# "-DFLASHINFER_ENABLE_BF16"; if that entry is still present in the file, this
# append duplicates it -- confirm against the actual file.
if sm_version >= 80:
    nvcc_flags += ["-DFLASHINFER_ENABLE_BF16"]
|
||||
|
||||
for flag in [
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
|
||||
Reference in New Issue
Block a user