diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 80f29921f..4d86f6293 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -157,6 +157,7 @@ set(SGL_KERNEL_CUDA_FLAGS
   "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
   "--expt-relaxed-constexpr"
   "--expt-extended-lambda"
+
   # The following flag leads to the CMAKE_BUILD_PARALLEL_LEVEL breaking,
   # it triggers OOM with low memory host. Extract the threads number to
   # option named SGL_KERNEL_COMPILE_THREADS, default value 32.
@@ -169,7 +170,8 @@ set(SGL_KERNEL_CUDA_FLAGS
   "-Xcompiler=-Wno-terminate"
   "-Xcompiler=-Wfatal-errors"
   "-Xcompiler=-ftemplate-backtrace-limit=1"
-  "-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
+  "-Xcudafe=--diag_suppress=177"   # variable was declared but never referenced
+  "-Xcudafe=--diag_suppress=2361"  # invalid narrowing conversion from "char" to "signed char"

   # uncomment to debug
   # "--ptxas-options=-v"
@@ -299,11 +301,12 @@ set(SOURCES
   "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"

+  "csrc/mamba/causal_conv1d.cu"
+
   "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
   "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
   "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
   "csrc/moe/marlin_moe_wna16/ops.cu"
-  "csrc/mamba/causal_conv1d.cu"
   "csrc/moe/moe_align_kernel.cu"
   "csrc/moe/moe_fused_gate.cu"
   "csrc/moe/moe_topk_softmax_kernels.cu"

diff --git a/sgl-kernel/cmake/utils.cmake b/sgl-kernel/cmake/utils.cmake
index 0eaa7a61a..8d676e479 100644
--- a/sgl-kernel/cmake/utils.cmake
+++ b/sgl-kernel/cmake/utils.cmake
@@ -11,11 +11,9 @@
 #
 macro(clear_cuda_arches CUDA_ARCH_FLAGS)
   # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
-  string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
-    ${CMAKE_CUDA_FLAGS})
+  string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS "${CMAKE_CUDA_FLAGS}")

   # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
   # and passed back via the `CUDA_ARCHITECTURES` property.
-  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
-    ${CMAKE_CUDA_FLAGS})
+  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
 endmacro()
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 4f95c9138..ad67248a9 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -99,6 +99,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "mult, int offset, int cuda_stream) -> ()");
   m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);

+  m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
+  m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
+  m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
+  m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
+
   /*
    * From csrc/gemm
    */
@@ -447,11 +452,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor _ascales, Tensor! _out_feats) -> ()");
   m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);

-  m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
-  m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
-  m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
-  m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
-
   /*
    * From csrc/mamba
    */
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index a13af546a..e1ac17de7 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -170,6 +170,9 @@ void downcast_fp8(
     int64_t offset,
     int64_t cuda_stream);

+void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
+void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
+
 #ifdef USE_ROCM
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
 #endif
@@ -743,9 +746,6 @@ std::vector<at::Tensor> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
  */
 void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);

-void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
-void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
-
 /*
  * From csrc/mamba
  */
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 76c87d30b..9c676ea8b 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -34,11 +34,6 @@ from sgl_kernel.elementwise import (
     rmsnorm,
     silu_and_mul,
 )
-from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
-
-if torch.version.hip is not None:
-    from sgl_kernel.elementwise import gelu_quick
-
 from sgl_kernel.fused_moe import fused_marlin_moe
 from sgl_kernel.gemm import (
     awq_dequantize,
@@ -71,6 +66,7 @@ from sgl_kernel.kvcacheio import (
     transfer_kv_per_layer,
     transfer_kv_per_layer_mla,
 )
+from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
 from sgl_kernel.marlin import (
     awq_marlin_moe_repack,
     awq_marlin_repack,
@@ -104,6 +100,9 @@ from sgl_kernel.speculative import (
 from sgl_kernel.top_k import fast_topk
 from sgl_kernel.version import __version__

+if torch.version.hip is not None:
+    from sgl_kernel.elementwise import gelu_quick
+

 def create_greenctx_stream_by_value(*args, **kwargs):
     from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
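
Reviewer note (not part of the diff): a minimal smoke test for the two op registrations that move within `common_extension.cc`. Only the op names and schemas come from the diff; the shapes, dtypes, pinned-memory requirement, and the MLA head-dim split below are illustrative assumptions, not taken from the kernel source.

```python
import torch
import sgl_kernel  # noqa: F401 -- importing registers the sgl_kernel torch library

# copy_to_gpu_no_ce(Tensor input, Tensor! output): copy host -> device into a
# preallocated buffer. Pinned host memory is assumed here, since the op name
# suggests it avoids the copy engine.
src = torch.randn(1024, dtype=torch.float32, pin_memory=True)
dst = torch.empty(1024, dtype=torch.float32, device="cuda")
torch.ops.sgl_kernel.copy_to_gpu_no_ce(src, dst)
assert torch.equal(dst.cpu(), src)

# concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope): fill the combined k
# buffer in place. The 512/64 split and the rope part being shared across
# heads are assumptions modeled on DeepSeek-style MLA; adjust to the actual
# layout expected by csrc/elementwise (see concat_mla_k's implementation).
tokens, heads, d_nope, d_rope = 8, 16, 512, 64
k_nope = torch.randn(tokens, heads, d_nope, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(tokens, 1, d_rope, device="cuda", dtype=torch.bfloat16)
k = torch.empty(tokens, heads, d_nope + d_rope, device="cuda", dtype=torch.bfloat16)
torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
```

Since the registrations only change position inside `TORCH_LIBRARY_FRAGMENT` (and the matching declarations move within `sgl_kernel_ops.h`), both calls should behave exactly as before; the test only confirms the schemas still resolve after the reordering.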