From 6b45a21d16a34f23ab2e6ff945987aaa076cfba9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 3 Mar 2025 05:32:30 -0800 Subject: [PATCH] Reorganize c++ source files in sgl-kernel with multiple folders (#4025) --- sgl-kernel/setup.py | 28 +-- sgl-kernel/setup_rocm.py | 4 +- .../fused_add_rms_norm_kernel.cu | 0 .../{ => allreduce}/custom_all_reduce.hip | 0 .../{ => allreduce}/custom_all_reduce_hip.cuh | 0 .../{ => allreduce}/trt_reduce_internal.cu | 0 .../csrc/{ => allreduce}/trt_reduce_kernel.cu | 0 .../csrc/{ => gemm}/cublas_grouped_gemm.cu | 0 .../{ => gemm}/fp8_blockwise_gemm_kernel.cu | 0 .../csrc/{ => gemm}/fp8_gemm_kernel.cu | 0 .../csrc/{ => gemm}/int8_gemm_kernel.cu | 0 .../{ => gemm}/per_token_group_quant_fp8.cu | 0 .../csrc/{ => moe}/moe_align_kernel.cu | 0 .../csrc/{ => speculative}/eagle_utils.cu | 0 .../{ => speculative}/speculative_sampling.cu | 2 +- .../speculative_sampling.cuh | 0 .../src/sgl-kernel/include/sgl_kernels_ops.h | 170 ++++++++-------- sgl-kernel/src/sgl-kernel/torch_extension.cc | 185 +++++++++--------- .../src/sgl-kernel/torch_extension_rocm.cc | 10 +- ...st_trt_reduce.py => test_trt_allreduce.py} | 14 +- 20 files changed, 203 insertions(+), 210 deletions(-) rename sgl-kernel/src/sgl-kernel/csrc/{ => activation}/fused_add_rms_norm_kernel.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => allreduce}/custom_all_reduce.hip (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => allreduce}/custom_all_reduce_hip.cuh (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => allreduce}/trt_reduce_internal.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => allreduce}/trt_reduce_kernel.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => gemm}/cublas_grouped_gemm.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => gemm}/fp8_blockwise_gemm_kernel.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => gemm}/fp8_gemm_kernel.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => gemm}/int8_gemm_kernel.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => gemm}/per_token_group_quant_fp8.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => moe}/moe_align_kernel.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => speculative}/eagle_utils.cu (100%) rename sgl-kernel/src/sgl-kernel/csrc/{ => speculative}/speculative_sampling.cu (99%) rename sgl-kernel/src/sgl-kernel/csrc/{ => speculative}/speculative_sampling.cuh (100%) rename sgl-kernel/tests/{test_trt_reduce.py => test_trt_allreduce.py} (97%) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 532f90601..2c9a8d089 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -80,6 +80,12 @@ nvcc_flags = [ "-std=c++17", "-use_fast_math", "-DFLASHINFER_ENABLE_F16", + "-DCUTLASS_VERSIONS_GENERATED", + "-DCUTE_USE_PACKED_TUPLE=1", + "-DCUTLASS_TEST_LEVEL=0", + "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", + "--ptxas-options=-v", "-Xcompiler=-Wconversion", "-Xcompiler=-fno-strict-aliasing", ] @@ -91,18 +97,18 @@ nvcc_flags_fp8 = [ sources = [ "src/sgl-kernel/torch_extension.cc", - "src/sgl-kernel/csrc/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/int8_gemm_kernel.cu", - "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", - "src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu", + "src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu", + "src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu", + 
"src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu", + "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu", + "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu", + "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu", + "src/sgl-kernel/csrc/moe/moe_align_kernel.cu", + "src/sgl-kernel/csrc/speculative/eagle_utils.cu", + "src/sgl-kernel/csrc/speculative/speculative_sampling.cu", "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", - "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu", - "src/sgl-kernel/csrc/eagle_utils.cu", - "src/sgl-kernel/csrc/speculative_sampling.cu", - "src/sgl-kernel/csrc/per_token_group_quant_fp8.cu", - "src/sgl-kernel/csrc/cublas_grouped_gemm.cu", "3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/norm.cu", diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index e3ad6c546..89e6ed9ac 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -43,8 +43,8 @@ include_dirs = [ sources = [ "src/sgl-kernel/torch_extension_rocm.cc", - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/custom_all_reduce.hip", + "src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip", + "src/sgl-kernel/csrc/moe/moe_align_kernel.cu", ] cxx_flags = ["-O3"] diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu rename to sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip b/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip rename to sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip diff --git a/sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh b/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh rename to sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu rename to sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu rename to sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/cublas_grouped_gemm.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/cublas_grouped_gemm.cu rename to sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu rename to sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu similarity index 100% rename from 
sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu rename to sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu rename to sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu rename to sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu rename to sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu b/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu rename to sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu similarity index 99% rename from sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu rename to sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu index a24945510..379a2a22c 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu @@ -14,9 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include "pytorch_extension_utils.h" +#include "speculative_sampling.cuh" using namespace flashinfer; diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh rename to sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index 6e3eab1af..fcc2c6139 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -35,7 +35,24 @@ limitations under the License. 
} using fptr_t = int64_t; + +/* + * From csrc/activation + */ +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); +void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); +void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + int64_t cuda_stream); +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +/* + * From csrc/allreduce + */ #ifdef USE_ROCM +// ROCM custom allreduce fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, const std::vector& offsets, int64_t rank, bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); @@ -50,7 +67,7 @@ void register_graph_buffers(fptr_t _fa, const std::vector& handles, torch::Tensor allocate_meta_buffer(int64_t size); torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); #else -// trt_reduce +// TRTLLM custom allreduce fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, const std::vector& tmp_result_buffers, const std::vector& barrier_in, const std::vector& barrier_out); @@ -61,94 +78,34 @@ void register_graph_buffers(fptr_t _fa, const std::vector>& const std::vector>& offsets); #endif -// moe_align_block_size +/* + * From csrc/gemm + */ +torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); +torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Dtype& out_dtype); +void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, + double eps, double fp8_min, double fp8_max); +void cublas_grouped_gemm(const std::vector& inputs, const std::vector& weights, + const std::vector& outputs, const torch::Dtype& out_dtype, + int64_t cublas_handle, int64_t cuda_stream); + +/* + * From csrc/moe + */ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); -// int8_scaled_mm -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); - -// fp8_scaled_mm -torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); - -// fp8_blockwise_scaled_mm -torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, 
const torch::Tensor& scales_b, - const torch::Dtype& out_dtype); - -// lightning_attention_decode -void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, - const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, - torch::Tensor new_kv); - -// rms norm -void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); - -// fused rms norm -void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); - -// gemma rms norm -void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); - -// fused gemma rms norm -void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, - int64_t cuda_stream); - -// silu and mul -void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); - -// gelu tanh and mul -void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); - -// gelu and mul -void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); - -// bmm fp8 -void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, - at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); - -// min p sampling from probs -void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, - std::optional maybe_min_p_arr, double min_p_val, bool deterministic, - int64_t cuda_stream); - -// top k renorm probs -// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. -void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, - unsigned int top_k_val, int64_t cuda_stream); - -// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. 
-// wrapper for binding -inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, - std::optional maybe_top_k_arr, int64_t top_k_val, - int64_t cuda_stream) { - top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); -} - -// top p renorm probs -void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, - double top_p_val, int64_t cuda_stream); - -// top k top p sampling from probs -void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, - at::Tensor success, std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic, - int64_t cuda_stream); - -// top p sampling from probs -void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic, - int64_t cuda_stream); - -void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, - int64_t cuda_stream); - +/* + * From csrc/speculative + */ void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, at::Tensor accept_token_num, // mutable at::Tensor candidates, at::Tensor retrive_index, @@ -165,11 +122,40 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, int64_t depth, int64_t draft_token_num); -// sgl_per_token_group_quant_fp8 -void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, - double eps, double fp8_min, double fp8_max); +/* + * From FlashInfer + */ +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + std::optional maybe_min_p_arr, double min_p_val, bool deterministic, + int64_t cuda_stream); +// top k renorm probs +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, + unsigned int top_k_val, int64_t cuda_stream); +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. 
+inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, int64_t top_k_val, + int64_t cuda_stream) { + top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); +} +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, + double top_p_val, int64_t cuda_stream); +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); +void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); +void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, + int64_t cuda_stream); -// cublas grouped gemm -void cublas_grouped_gemm(const std::vector& inputs, const std::vector& weights, - const std::vector& outputs, const torch::Dtype& out_dtype, - int64_t cublas_handle, int64_t cuda_stream); +/* + * Other + */ +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv); diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc index 0585911c4..b71a83cf3 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -19,7 +19,33 @@ limitations under the License. #include "sgl_kernels_ops.h" TORCH_LIBRARY_EXPAND(sgl_kernels, m) { - // trt_reduce + /* + * From csrc/activation + */ + m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("rmsnorm", torch::kCUDA, &rmsnorm); + + m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()"); + m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); + + m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); + + m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); + + m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + /* + * From csrc/allreduce + */ m.def( "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] " "barrier_in, int[] barrier_out) -> int"); @@ -36,108 +62,49 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers); - // moe_align_block_size - m.def( - "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! 
sorted_token_ids, Tensor! " - "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); - m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - - // int8_scaled_mm + /* + * From csrc/gemm + */ m.def( "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " "bias) -> Tensor"); m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); - // fp8_scaled_mm m.def( "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " "bias) -> Tensor"); m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm); - // fp8_blockwise_scaled_mm m.def( "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> " "Tensor"); m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm); - // lightning_attention_decode + m.def( + "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps, float fp8_min, float fp8_max) -> ()"); + m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); + + m.def( + "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs," + " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()"); + m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm); + + /* + * From csrc/moe + */ + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + m.def( "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " "new_kv) -> ()"); - m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); - // rms norm - m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("rmsnorm", torch::kCUDA, &rmsnorm); - - // fused rms norm - m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()"); - m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); - - // gemma rms norm - m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); - - // fused gemma rms norm - m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); - m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); - - // silu and mul - m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - - // gelu tanh and mul - m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - - // gelu and mul - m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); - m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); - - // bmm fp8 - m.def( - "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " - "cublas_handle, int cuda_stream) -> ()"); - m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); - - // min p sampling from probs - m.def( - "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? 
maybe_min_p_arr, float " - "min_p_val, bool deterministic, int cuda_stream) -> ()"); - m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); - - // top k renorm probs - m.def( - "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " - "cuda_stream) -> ()"); - m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); - - // top p renorm probs - m.def( - "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " - "cuda_stream) -> ()"); - m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); - - // top k top p sampling from probs - m.def( - "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " - "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int " - "cuda_stream) -> ()"); - m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); - - // top p sampling from probs - m.def( - "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " - "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); - m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); - - // apply rope with cos sin cache - m.def( - "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " - "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); - m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); - - // tree spec decode + /* + * From csrc/speculative + */ m.def( "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " @@ -145,7 +112,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { "bool deterministic, int cuda_stream) -> ()"); m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only); - // eagle build tree m.def( "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! " @@ -153,24 +119,55 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { "int topk, int depth, int draft_token_num) -> ()"); m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); - // eagle build tree m.def( "build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, " "int topk, int depth, int draft_token_num) -> ()"); m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel); - // per_token_group_quant_fp8 + /* + * From FlashInfer + */ m.def( - "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size," - " float eps, float fp8_min, float fp8_max) -> ()"); - m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); + "bmm_fp8(Tensor A, Tensor B, Tensor! 
D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); - // cublas grouped gemm m.def( - "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs," - " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()"); - m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm); + "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float " + "min_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); + + m.def( + "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " + "cuda_stream) -> ()"); + m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); + + m.def( + "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " + "cuda_stream) -> ()"); + m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); + + m.def( + "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int " + "cuda_stream) -> ()"); + m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); + + m.def( + "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); + + m.def( + "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " + "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); + m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); + + /* + * Other + */ + m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); } REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc index 2c41bb57e..95adea90b 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc @@ -19,7 +19,9 @@ limitations under the License. #include "sgl_kernels_ops.h" TORCH_LIBRARY_EXPAND(sgl_kernels, m) { - // Custom all-reduce kernels + /* + * From csrc/allreduce + */ m.def( "init_custom_ar(Tensor meta, Tensor rank_data, " "str[] handles, int[] offsets, int rank, " @@ -45,12 +47,16 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); m.def("register_graph_buffers", ®ister_graph_buffers); + m.def("allocate_meta_buffer", &allocate_meta_buffer); m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer); + m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle); m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle); - // moe_align_block_size + /* + * From csrc/moe + */ m.def( "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! 
cumsum_buffer) -> ()"); diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_allreduce.py similarity index 97% rename from sgl-kernel/tests/test_trt_reduce.py rename to sgl-kernel/tests/test_trt_allreduce.py index b79580070..caf92183d 100644 --- a/sgl-kernel/tests/test_trt_reduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -1,11 +1,10 @@ import ctypes import logging -import os import random import socket import time import unittest -from typing import Any, List, Optional, Union +from typing import Any, List, Optional import ray import torch @@ -115,7 +114,7 @@ class TestCustomAllReduce(unittest.TestCase): self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}") + 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0") ) self.custom_ptr = custom_ops.init_custom_reduce( @@ -148,7 +147,7 @@ class TestCustomAllReduce(unittest.TestCase): self.vllm_max_size, group=group ) self.vllm_rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}") + 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0") ) self.vllm_ptr = vllm_ops.init_custom_ar( self.vllm_meta_ptrs, self.vllm_rank_data, rank, True @@ -171,8 +170,7 @@ class TestCustomAllReduce(unittest.TestCase): @staticmethod def init_distributed_env(world_size, rank, distributed_init_port): - del os.environ["CUDA_VISIBLE_DEVICES"] - device = torch.device(f"cuda:{rank}") + device = torch.device("cuda:0") torch.cuda.set_device(device) ranks = [i for i in range(world_size)] distributed_init_method = f"tcp://localhost:{distributed_init_port}" @@ -234,8 +232,8 @@ class TestCustomAllReduce(unittest.TestCase): if rank == 0: logger.warning( f"test_size = {sz}, world_size = {world_size}, " - f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms," - f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms" + f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, " + f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms " ) self.free_custom_allreduce(group)