Fix the style of sgl kernel (#10398)
This commit is contained in:
@@ -99,6 +99,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"mult, int offset, int cuda_stream) -> ()");
|
||||
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
|
||||
|
||||
m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
|
||||
m.impl("copy_to_gpu_no_ce", torch::kCUDA, ©_to_gpu_no_ce);
|
||||
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
|
||||
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
@@ -447,11 +452,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
||||
|
||||
m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
|
||||
m.impl("copy_to_gpu_no_ce", torch::kCUDA, ©_to_gpu_no_ce);
|
||||
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
|
||||
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
|
||||
|
||||
/*
|
||||
* From csrc/mamba
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user