diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index fc7a46417..6de25bb01 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -51,6 +51,47 @@ Steps to add a new kernel: 4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include new CUDA source 5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel) +### Development Tips + +1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need to use `Torch::tensor`, use `` instead of ``. Using `` will cause compilation errors when using SABI. + +2. When creating torch extensions, simply add the function definition with `m.def`: + ```cpp + m.def("register_graph_buffers", register_graph_buffers); + ``` + +3. When exposing Python interfaces, avoid using kwargs in C++ interface kernels. + + **Avoid this:** + + ```cpp + torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default( + q=query.view(query.shape[0], -1, head_size), + k=key.view(key.shape[0], -1, head_size), + q_rope=query.view(query.shape[0], -1, head_size), + k_rope=key.view(key.shape[0], -1, head_size), + cos_sin_cache=cos_sin_cache, + pos_ids=positions.long(), + interleave=(not is_neox), + cuda_stream=get_cuda_stream(), + ) + ``` + + **Use this instead:** + + ```cpp + torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default( + query.view(query.shape[0], -1, head_size), + key.view(key.shape[0], -1, head_size), + query.view(query.shape[0], -1, head_size), + key.view(key.shape[0], -1, head_size), + cos_sin_cache, + positions.long(), + (not is_neox), + get_cuda_stream(), + ) + ``` + ### Build & Install Development build: diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index f2d5eba73..3633c9f40 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -22,121 +22,49 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From csrc/allreduce */ - m.def( - "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] " - "barrier_in, int[] barrier_out) -> int"); - m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - - m.def("dispose", &dispose); - - m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()"); - m.impl("all_reduce", torch::kCUDA, &all_reduce); - - m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); - m.def("register_graph_buffers", ®ister_graph_buffers); + m.def("init_custom_ar", init_custom_ar); + m.def("dispose", dispose); + m.def("all_reduce", all_reduce); + m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta); + m.def("register_graph_buffers", register_graph_buffers); /* * From csrc/attention */ - m.def( - "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " - "new_kv) -> ()"); - m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + m.def("lightning_attention_decode", lightning_attention_decode); /* * From csrc/elementwise */ - m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("rmsnorm", torch::kCUDA, &rmsnorm); - - m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()"); - m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); - - m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); - - m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); - - m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - - m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - - m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); - - m.def( - "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " - "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); - m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); + m.def("rmsnorm", rmsnorm); + m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm); + m.def("gemma_rmsnorm", gemma_rmsnorm); + m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm); + m.def("silu_and_mul", silu_and_mul); + m.def("gelu_tanh_and_mul", gelu_tanh_and_mul); + m.def("gelu_and_mul", gelu_and_mul); + m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache); /* * From csrc/gemm */ - m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor"); - m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); - - m.def( - "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " - "bias) -> Tensor"); - m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); - - m.def( - "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " - "bias) -> Tensor"); - m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm); - - m.def( - "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> " - "Tensor"); - m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm); - - m.def( - "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size," - " float eps, float fp8_min, float fp8_max) -> ()"); - m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); - - m.def( - "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size," - " float eps, float int8_min, float int8_max) -> ()"); - m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8); - - m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()"); - m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8); - - m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()"); - m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8); - - m.def( - "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs," - " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()"); - m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm); - - m.def( - "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," - " Tensor block_scale_a, Tensor block_scale_b," - " Tensor alpha) -> ()"); - m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); - - m.def( - "scaled_fp4_quant(Tensor! output, Tensor! input," - " Tensor! output_scale, Tensor! input_scale) -> ()"); - m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + m.def("awq_dequantize", awq_dequantize); + m.def("int8_scaled_mm", int8_scaled_mm); + m.def("fp8_scaled_mm", fp8_scaled_mm); + m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm); + m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8); + m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8); + m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8); + m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8); + m.def("cublas_grouped_gemm", cublas_grouped_gemm); + m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm); + m.def("scaled_fp4_quant", scaled_fp4_quant); /* * From csrc/moe */ - m.def( - "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " - "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); - m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - - m.def( - "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " - "token_expert_indices, Tensor gating_output) -> ()"); - m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + m.def("moe_align_block_size", moe_align_block_size); + m.def("topk_softmax", topk_softmax); m.def( "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> " @@ -146,62 +74,20 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From csrc/speculative */ - m.def( - "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " - "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " - "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, " - "float threshold_single, float threshold_acc, " - "bool deterministic, int cuda_stream) -> ()"); - m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only); - - m.def( - "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " - "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " - "Tensor target_predict, int cuda_stream) -> ()"); - m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy); - - m.def( - "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " - "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " - "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); - m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); - - m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()"); - m.impl("segment_packbits", torch::kCUDA, &segment_packbits); + m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only); + m.def("verify_tree_greedy", verify_tree_greedy); + m.def("build_tree_kernel_efficient", build_tree_kernel_efficient); + m.def("segment_packbits", segment_packbits); /* * From FlashInfer */ - m.def( - "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " - "cublas_handle, int cuda_stream) -> ()"); - m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); - - m.def( - "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float " - "min_p_val, bool deterministic, int cuda_stream) -> ()"); - m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); - - m.def( - "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " - "cuda_stream) -> ()"); - m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs); - - m.def( - "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " - "cuda_stream) -> ()"); - m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); - - m.def( - "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " - "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int " - "cuda_stream) -> ()"); - m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); - - m.def( - "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " - "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); - m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); + m.def("bmm_fp8", bmm_fp8); + m.def("min_p_sampling_from_probs", min_p_sampling_from_probs); + m.def("top_k_renorm_probs", top_k_renorm_probs); + m.def("top_p_renorm_probs", top_p_renorm_probs); + m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs); + m.def("top_p_sampling_from_probs", top_p_sampling_from_probs); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 3db157156..307df2a5b 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -142,12 +142,12 @@ def apply_rope_with_cos_sin_cache_inplace( raise ValueError("cos_sin_cache should be float32") torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default( - q=query.view(query.shape[0], -1, head_size), - k=key.view(key.shape[0], -1, head_size), - q_rope=query.view(query.shape[0], -1, head_size), - k_rope=key.view(key.shape[0], -1, head_size), - cos_sin_cache=cos_sin_cache, - pos_ids=positions.long(), - interleave=(not is_neox), - cuda_stream=get_cuda_stream(), + query.view(query.shape[0], -1, head_size), + key.view(key.shape[0], -1, head_size), + query.view(query.shape[0], -1, head_size), + key.view(key.shape[0], -1, head_size), + cos_sin_cache, + positions.long(), + (not is_neox), + get_cuda_stream(), )