sync the upstream updates of flashinfer (#3051)
This commit is contained in:
@@ -47,6 +47,7 @@ include_dirs = [
|
||||
cutlass.resolve() / "tools" / "util" / "include",
|
||||
root / "src" / "sgl-kernel" / "csrc",
|
||||
flashinfer.resolve() / "include",
|
||||
flashinfer.resolve() / "include" / "gemm",
|
||||
flashinfer.resolve() / "csrc",
|
||||
]
|
||||
nvcc_flags = [
|
||||
@@ -91,7 +92,12 @@ ext_modules = [
|
||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||
"3rdparty/flashinfer/csrc/group_gemm.cu",
|
||||
"3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
|
||||
"3rdparty/flashinfer/csrc/norm.cu",
|
||||
"3rdparty/flashinfer/csrc/sampling.cu",
|
||||
],
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
|
||||
Reference in New Issue
Block a user