Fix the style of sgl kernel (#10398)
This commit is contained in:
@@ -157,6 +157,7 @@ set(SGL_KERNEL_CUDA_FLAGS
|
|||||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
||||||
"--expt-relaxed-constexpr"
|
"--expt-relaxed-constexpr"
|
||||||
"--expt-extended-lambda"
|
"--expt-extended-lambda"
|
||||||
|
|
||||||
# The following flag leads to the CMAKE_BUILD_PARALLEL_LEVEL breaking,
|
# The following flag leads to the CMAKE_BUILD_PARALLEL_LEVEL breaking,
|
||||||
# it triggers OOM with low memory host. Extract the threads number to
|
# it triggers OOM with low memory host. Extract the threads number to
|
||||||
# option named SGL_KERNEL_COMPILE_THREADS, default value 32.
|
# option named SGL_KERNEL_COMPILE_THREADS, default value 32.
|
||||||
@@ -169,7 +170,8 @@ set(SGL_KERNEL_CUDA_FLAGS
|
|||||||
"-Xcompiler=-Wno-terminate"
|
"-Xcompiler=-Wno-terminate"
|
||||||
"-Xcompiler=-Wfatal-errors"
|
"-Xcompiler=-Wfatal-errors"
|
||||||
"-Xcompiler=-ftemplate-backtrace-limit=1"
|
"-Xcompiler=-ftemplate-backtrace-limit=1"
|
||||||
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
|
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
|
||||||
|
"-Xcudafe=--diag_suppress=2361" # invalid narrowing conversion from "char" to "signed char"
|
||||||
|
|
||||||
# uncomment to debug
|
# uncomment to debug
|
||||||
# "--ptxas-options=-v"
|
# "--ptxas-options=-v"
|
||||||
@@ -299,11 +301,12 @@ set(SOURCES
|
|||||||
|
|
||||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||||
|
|
||||||
|
"csrc/mamba/causal_conv1d.cu"
|
||||||
|
|
||||||
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
|
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
|
||||||
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
||||||
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
||||||
"csrc/moe/marlin_moe_wna16/ops.cu"
|
"csrc/moe/marlin_moe_wna16/ops.cu"
|
||||||
"csrc/mamba/causal_conv1d.cu"
|
|
||||||
"csrc/moe/moe_align_kernel.cu"
|
"csrc/moe/moe_align_kernel.cu"
|
||||||
"csrc/moe/moe_fused_gate.cu"
|
"csrc/moe/moe_fused_gate.cu"
|
||||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||||
|
|||||||
@@ -11,11 +11,9 @@
|
|||||||
#
|
#
|
||||||
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
|
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
|
||||||
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
||||||
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
|
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||||
${CMAKE_CUDA_FLAGS})
|
|
||||||
|
|
||||||
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
||||||
# and passed back via the `CUDA_ARCHITECTURES` property.
|
# and passed back via the `CUDA_ARCHITECTURES` property.
|
||||||
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||||
${CMAKE_CUDA_FLAGS})
|
|
||||||
endmacro()
|
endmacro()
|
||||||
|
|||||||
@@ -99,6 +99,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"mult, int offset, int cuda_stream) -> ()");
|
"mult, int offset, int cuda_stream) -> ()");
|
||||||
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
|
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
|
||||||
|
|
||||||
|
m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
|
||||||
|
m.impl("copy_to_gpu_no_ce", torch::kCUDA, ©_to_gpu_no_ce);
|
||||||
|
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
|
||||||
|
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/gemm
|
* From csrc/gemm
|
||||||
*/
|
*/
|
||||||
@@ -447,11 +452,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
||||||
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
||||||
|
|
||||||
m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
|
|
||||||
m.impl("copy_to_gpu_no_ce", torch::kCUDA, ©_to_gpu_no_ce);
|
|
||||||
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
|
|
||||||
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/mamba
|
* From csrc/mamba
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -170,6 +170,9 @@ void downcast_fp8(
|
|||||||
int64_t offset,
|
int64_t offset,
|
||||||
int64_t cuda_stream);
|
int64_t cuda_stream);
|
||||||
|
|
||||||
|
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
|
||||||
|
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||||
#endif
|
#endif
|
||||||
@@ -743,9 +746,6 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
|||||||
*/
|
*/
|
||||||
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
|
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
|
||||||
|
|
||||||
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
|
|
||||||
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/mamba
|
* From csrc/mamba
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -34,11 +34,6 @@ from sgl_kernel.elementwise import (
|
|||||||
rmsnorm,
|
rmsnorm,
|
||||||
silu_and_mul,
|
silu_and_mul,
|
||||||
)
|
)
|
||||||
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
|
|
||||||
|
|
||||||
if torch.version.hip is not None:
|
|
||||||
from sgl_kernel.elementwise import gelu_quick
|
|
||||||
|
|
||||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||||
from sgl_kernel.gemm import (
|
from sgl_kernel.gemm import (
|
||||||
awq_dequantize,
|
awq_dequantize,
|
||||||
@@ -71,6 +66,7 @@ from sgl_kernel.kvcacheio import (
|
|||||||
transfer_kv_per_layer,
|
transfer_kv_per_layer,
|
||||||
transfer_kv_per_layer_mla,
|
transfer_kv_per_layer_mla,
|
||||||
)
|
)
|
||||||
|
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
|
||||||
from sgl_kernel.marlin import (
|
from sgl_kernel.marlin import (
|
||||||
awq_marlin_moe_repack,
|
awq_marlin_moe_repack,
|
||||||
awq_marlin_repack,
|
awq_marlin_repack,
|
||||||
@@ -104,6 +100,9 @@ from sgl_kernel.speculative import (
|
|||||||
from sgl_kernel.top_k import fast_topk
|
from sgl_kernel.top_k import fast_topk
|
||||||
from sgl_kernel.version import __version__
|
from sgl_kernel.version import __version__
|
||||||
|
|
||||||
|
if torch.version.hip is not None:
|
||||||
|
from sgl_kernel.elementwise import gelu_quick
|
||||||
|
|
||||||
|
|
||||||
def create_greenctx_stream_by_value(*args, **kwargs):
|
def create_greenctx_stream_by_value(*args, **kwargs):
|
||||||
from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
|
from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
|
||||||
|
|||||||
Reference in New Issue
Block a user