Minor style fixes for sgl-kernel (#9289)

This commit is contained in:
Lianmin Zheng
2025-08-18 09:38:35 -07:00
committed by GitHub
parent 6e316588f8
commit c480a3f6ea
17 changed files with 439 additions and 109 deletions

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/allreduce
@@ -93,6 +94,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
m.def(
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
"mult, int offset, int cuda_stream) -> ()");
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
/*
* From csrc/gemm
*/
@@ -161,7 +167,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
// GPTQ related method
/*
* From csrc/gemm/gptq
*/
m.def(
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
@@ -183,6 +191,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
/*
* From csrc/moe
*/
@@ -229,6 +238,41 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From csrc/speculative
*/
@@ -306,25 +350,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From FlashInfer
*/
@@ -358,19 +383,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From Sparse Flash Attention
*/