diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml
index 495bf68c8..f589119e6 100644
--- a/.github/workflows/release-pypi-kernel.yml
+++ b/.github/workflows/release-pypi-kernel.yml
@@ -5,7 +5,7 @@ on:
     branches:
       - main
     paths:
-      - sgl-kernel/src/sgl-kernel/version.py
+      - sgl-kernel/python/sgl_kernel/version.py
   workflow_dispatch:

 concurrency:
diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml
index 5eaa0127f..631551475 100644
--- a/.github/workflows/release-whl-kernel.yml
+++ b/.github/workflows/release-whl-kernel.yml
@@ -9,7 +9,7 @@ on:
     branches:
       - main
     paths:
-      - sgl-kernel/src/sgl-kernel/version.py
+      - sgl-kernel/python/sgl_kernel/version.py

 jobs:
   build-wheels:
@@ -59,7 +59,7 @@ jobs:
         id: set_tag_name
         run: |
           if [ -z "${{ inputs.tag_name }}" ]; then
-            TAG_NAME="v$(cat sgl-kernel/src/sgl-kernel/version.py | cut -d'"' -f2)"
+            TAG_NAME="v$(cat sgl-kernel/python/sgl_kernel/version.py | cut -d'"' -f2)"
             echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT
           else
             echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT
diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index c5056ffc2..d06765c3a 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -75,42 +75,42 @@ else:
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return sgl_kernel.ops.allreduce.init_custom_ar(
+        return sgl_kernel.allreduce.init_custom_ar(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )

     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.ops.allreduce.all_reduce_reg(fa, inp, out)
+        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)

     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        sgl_kernel.ops.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
+        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)

     def dispose(fa: int) -> None:
-        sgl_kernel.ops.allreduce.dispose(fa)
+        sgl_kernel.allreduce.dispose(fa)

     def meta_size() -> int:
-        return sgl_kernel.ops.allreduce.meta_size()
+        return sgl_kernel.allreduce.meta_size()

     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return sgl_kernel.ops.allreduce.register_buffer(fa, t, handles, offsets)
+        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)

     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return sgl_kernel.ops.allreduce.get_graph_buffer_ipc_meta(fa)
+        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)

     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        sgl_kernel.ops.allreduce.register_graph_buffers(fa, handles, offsets)
+        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)

     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return sgl_kernel.ops.allreduce.allocate_meta_buffer(size)
+        return sgl_kernel.allreduce.allocate_meta_buffer(size)

     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return sgl_kernel.ops.allreduce.get_meta_buffer_ipc_handle(inp)
+        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)

 else:
     # TRTLLM custom allreduce
@@ -123,7 +123,7 @@ else:
         barrier_in: List[int],
         barrier_out: List[int],
     ) -> int:
-        return sgl_kernel.ops.init_custom_reduce(
+        return sgl_kernel.init_custom_reduce(
             rank_id,
             world_size,
             rank_data_base,
@@ -134,15 +134,15 @@ else:
         )

     def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.ops.custom_reduce(fa, inp, out)
+        sgl_kernel.custom_reduce(fa, inp, out)

     def dispose(fa: int) -> None:
-        sgl_kernel.ops.custom_dispose(fa)
+        sgl_kernel.custom_dispose(fa)

     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+        return sgl_kernel.get_graph_buffer_ipc_meta(fa)

     def register_graph_buffers(
         fa: int, handles: List[List[int]], offsets: List[List[int]]
     ) -> None:
-        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+        sgl_kernel.register_graph_buffers(fa, handles, offsets)
diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile
index 986e424f4..53375fa0f 100644
--- a/sgl-kernel/Makefile
+++ b/sgl-kernel/Makefile
@@ -38,12 +38,12 @@ test: ## Run all tests

 format: check-deps ## Format all source files
 	@echo "Formatting source files..."
-	@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
-	@find src tests -name '*.py' | xargs isort
-	@find src tests -name '*.py' | xargs black
+	@find csrc tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
+	@find python tests -name '*.py' | xargs isort
+	@find python tests -name '*.py' | xargs black
 	@pre-commit run --all-files

-FILES_TO_UPDATE = src/sgl-kernel/version.py \
+FILES_TO_UPDATE = python/sgl_kernel/version.py \
                   pyproject.toml

 update: ## Update version numbers across project files. Usage: make update
@@ -51,7 +51,7 @@ update: ## Update version numbers across project files. Usage: make update
 		"; \
 		exit 1; \
 	fi
-	@OLD_VERSION=$$(grep "version" src/sgl-kernel/version.py | cut -d '"' -f2); \
+	@OLD_VERSION=$$(grep "version" python/sgl_kernel/version.py | cut -d '"' -f2); \
 	NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
 	echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
 	for file in $(FILES_TO_UPDATE); do \
diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md
index 1f805cbd0..689f34be0 100644
--- a/sgl-kernel/README.md
+++ b/sgl-kernel/README.md
@@ -45,12 +45,11 @@ Third-party libraries:

 Steps to add a new kernel:

-1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
-2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
-3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
-4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
-5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
-6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
+1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc)
+2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h)
+3. Create the torch extension in [csrc/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/torch_extension.cc)
+4. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include the new CUDA source
+5. Expose the Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)

 ### Build & Install

@@ -72,4 +71,4 @@ The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, t

 ### Release new version

-Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
+Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py)
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip b/sgl-kernel/csrc/allreduce/custom_all_reduce.hip
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip
rename to sgl-kernel/csrc/allreduce/custom_all_reduce.hip
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh b/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
rename to sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu b/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
rename to sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu b/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
rename to sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu b/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
rename to sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
rename to sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
rename to sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
rename to sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
rename to sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
rename to sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
rename to sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
diff --git a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu b/sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu
rename to sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu b/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
rename to sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
rename to sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
rename to sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
rename to sgl-kernel/csrc/gemm/int8_gemm_kernel.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
rename to sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
rename to sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
rename to sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu
rename to sgl-kernel/csrc/moe/moe_align_kernel.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/csrc/speculative/eagle_utils.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu
rename to sgl-kernel/csrc/speculative/eagle_utils.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu b/sgl-kernel/csrc/speculative/speculative_sampling.cu
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu
rename to sgl-kernel/csrc/speculative/speculative_sampling.cu
diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/csrc/speculative/speculative_sampling.cuh
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh
rename to sgl-kernel/csrc/speculative/speculative_sampling.cuh
diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc
similarity index 98%
rename from sgl-kernel/src/sgl-kernel/torch_extension.cc
rename to sgl-kernel/csrc/torch_extension.cc
index a8ee87707..9fd32bf99 100644
--- a/sgl-kernel/src/sgl-kernel/torch_extension.cc
+++ b/sgl-kernel/csrc/torch_extension.cc
@@ -16,33 +16,9 @@ limitations under the License.
 #include 
 #include 

-#include "sgl_kernels_ops.h"
-
-TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
-  /*
-   * From csrc/activation
-   */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
-
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
-
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
-
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
-
-  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
-
-  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
-
-  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+#include "sgl_kernel_ops.h"
+TORCH_LIBRARY_EXPAND(sgl_kernel, m) {

   /*
    * From csrc/allreduce
    */
@@ -67,6 +43,30 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
    */
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);

+  /*
+   * From csrc/elementwise
+   */
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
+
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
+
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
+
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
+
+  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+
+  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+
   /*
    * From csrc/gemm
    */
@@ -93,6 +93,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);

+  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
+  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
+
   m.def(
       "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
       " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
@@ -171,9 +174,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
       "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
-
-  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
-  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
 }

-REGISTER_EXTENSION(_kernels)
+REGISTER_EXTENSION(common_ops)
diff --git a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc
similarity index 97%
rename from sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc
rename to sgl-kernel/csrc/torch_extension_rocm.cc
index 95adea90b..014e311cf 100644
--- a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc
+++ b/sgl-kernel/csrc/torch_extension_rocm.cc
@@ -16,9 +16,9 @@ limitations under the License.
 #include 
 #include 

-#include "sgl_kernels_ops.h"
+#include "sgl_kernel_ops.h"

-TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
+TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
similarity index 99%
rename from sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
rename to sgl-kernel/include/sgl_kernel_ops.h
index 5bc5c7083..82412b6e0 100644
--- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -36,18 +36,6 @@ limitations under the License.

 using fptr_t = int64_t;

-/*
- * From csrc/activation
- */
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
-void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void gemma_fused_add_rmsnorm(
-    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-
 /*
  * From csrc/allreduce
  */
@@ -88,6 +76,30 @@ void register_graph_buffers(
     fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
 #endif

+/*
+ * From csrc/attention
+ */
+void lightning_attention_decode(
+    const torch::Tensor& q,
+    const torch::Tensor& k,
+    const torch::Tensor& v,
+    const torch::Tensor& past_kv,
+    const torch::Tensor& slope,
+    torch::Tensor output,
+    torch::Tensor new_kv);
+
+/*
+ * From csrc/elementwise
+ */
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+void gemma_fused_add_rmsnorm(
+    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
 /*
  * From csrc/gemm
  */
@@ -120,6 +132,7 @@ void sgl_per_token_group_quant_fp8(
     double fp8_min,
     double fp8_max);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
+void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void cublas_grouped_gemm(
     const std::vector<at::Tensor>& inputs,
     const std::vector<at::Tensor>& weights,
@@ -254,18 +267,3 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor pos_ids,
     bool interleave,
     int64_t cuda_stream);
-
-/*
- * Other
- */
-void lightning_attention_decode(
-    const torch::Tensor& q,
-    const torch::Tensor& k,
-    const torch::Tensor& v,
-    const torch::Tensor& past_kv,
-    const torch::Tensor& slope,
-    torch::Tensor output,
-    torch::Tensor new_kv);
-
-// sgl_per_token_quant_fp8
-void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
diff --git a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/include/trt_reduce_internal.cuh
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh
rename to sgl-kernel/include/trt_reduce_internal.cuh
diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/include/utils.h
rename to sgl-kernel/include/utils.h
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 24325aeca..6c7eb3e60 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -20,10 +20,6 @@ dependencies = []
 "Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel"
 "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
"https://github.com/sgl-project/sglang/issues" -[tool.setuptools] -package-dir = {"sgl_kernel" = "src/sgl-kernel"} -packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"] - [tool.wheel] exclude = [ "dist*", diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py similarity index 74% rename from sgl-kernel/src/sgl-kernel/__init__.py rename to sgl-kernel/python/sgl_kernel/__init__.py index ab7f673b0..c8cb0443d 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -9,7 +9,10 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"): mode=ctypes.RTLD_GLOBAL, ) -from sgl_kernel.ops.activation import ( +from sgl_kernel import common_ops +from sgl_kernel.allreduce import * +from sgl_kernel.attention import lightning_attention_decode +from sgl_kernel.elementwise import ( apply_rope_with_cos_sin_cache_inplace, fused_add_rmsnorm, gelu_and_mul, @@ -19,9 +22,7 @@ from sgl_kernel.ops.activation import ( rmsnorm, silu_and_mul, ) -from sgl_kernel.ops.allreduce import * -from sgl_kernel.ops.attention import lightning_attention_decode -from sgl_kernel.ops.gemm import ( +from sgl_kernel.gemm import ( bmm_fp8, cublas_grouped_gemm, fp8_blockwise_scaled_mm, @@ -31,15 +32,15 @@ from sgl_kernel.ops.gemm import ( sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8, ) -from sgl_kernel.ops.moe import moe_align_block_size -from sgl_kernel.ops.sampling import ( +from sgl_kernel.moe import moe_align_block_size +from sgl_kernel.sampling import ( min_p_sampling_from_probs, top_k_renorm_prob, top_k_top_p_sampling_from_probs, top_p_renorm_prob, top_p_sampling_from_probs, ) -from sgl_kernel.ops.speculative import ( +from sgl_kernel.speculative import ( build_tree_kernel, build_tree_kernel_efficient, tree_speculative_sampling_target_only, diff --git a/sgl-kernel/src/sgl-kernel/ops/allreduce.py b/sgl-kernel/python/sgl_kernel/allreduce.py similarity index 62% rename from sgl-kernel/src/sgl-kernel/ops/allreduce.py rename to sgl-kernel/python/sgl_kernel/allreduce.py index 05079e3f4..0924e7f35 100644 --- a/sgl-kernel/src/sgl-kernel/ops/allreduce.py +++ b/sgl-kernel/python/sgl_kernel/allreduce.py @@ -1,6 +1,5 @@ from typing import List, Tuple -import sgl_kernel.ops._kernels import torch if torch.version.hip is not None: @@ -13,49 +12,49 @@ if torch.version.hip is not None: rank: int, full_nvlink: bool, ) -> int: - return torch.ops.sgl_kernels.init_custom_ar( + return torch.ops.sgl_kernel.init_custom_ar( meta, rank_data, handles, offsets, rank, full_nvlink ) def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out) + torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out) def all_reduce_unreg( fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor ) -> None: - torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out) + torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out) def dispose(fa: int) -> None: - torch.ops.sgl_kernels.dispose(fa) + torch.ops.sgl_kernel.dispose(fa) def meta_size() -> int: - return torch.ops.sgl_kernels.meta_size() + return torch.ops.sgl_kernel.meta_size() def register_buffer( fa: int, t: torch.Tensor, handles: List[str], offsets: List[int] ) -> None: - return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets) + return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets) def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]: - return 
torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) + return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa) def register_graph_buffers( fa: int, handles: List[str], offsets: List[List[int]] ) -> None: - torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) + torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets) def allocate_meta_buffer(size: int) -> torch.Tensor: - return torch.ops.sgl_kernels.allocate_meta_buffer(size) + return torch.ops.sgl_kernel.allocate_meta_buffer(size) def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: - return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp) + return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp) else: # TRTLLM custom allreduce def init_custom_reduce( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ): - return torch.ops.sgl_kernels.init_custom_ar( + return torch.ops.sgl_kernel.init_custom_ar( rank_id, num_devices, rank_data, @@ -66,13 +65,13 @@ else: ) def custom_dispose(fa): - torch.ops.sgl_kernels.dispose(fa) + torch.ops.sgl_kernel.dispose(fa) def custom_reduce(fa, inp, out): - torch.ops.sgl_kernels.all_reduce(fa, inp, out) + torch.ops.sgl_kernel.all_reduce(fa, inp, out) def get_graph_buffer_ipc_meta(fa): - return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) + return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa) def register_graph_buffers(fa, handles, offsets): - torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) + torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets) diff --git a/sgl-kernel/src/sgl-kernel/ops/attention.py b/sgl-kernel/python/sgl_kernel/attention.py similarity index 62% rename from sgl-kernel/src/sgl-kernel/ops/attention.py rename to sgl-kernel/python/sgl_kernel/attention.py index a4cb5fc0b..53fec4dd1 100644 --- a/sgl-kernel/src/sgl-kernel/ops/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -1,8 +1,7 @@ -import sgl_kernel.ops._kernels import torch def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): - torch.ops.sgl_kernels.lightning_attention_decode( + torch.ops.sgl_kernel.lightning_attention_decode( q, k, v, past_kv, slope, output, new_kv ) diff --git a/sgl-kernel/src/sgl-kernel/ops/activation.py b/sgl-kernel/python/sgl_kernel/elementwise.py similarity index 87% rename from sgl-kernel/src/sgl-kernel/ops/activation.py rename to sgl-kernel/python/sgl_kernel/elementwise.py index 08a65ec01..fc6d8ea00 100644 --- a/sgl-kernel/src/sgl-kernel/ops/activation.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -1,8 +1,7 @@ from typing import Optional -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import get_cuda_stream +from sgl_kernel.utils import get_cuda_stream # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer @@ -15,14 +14,14 @@ def rmsnorm( ) -> torch.Tensor: if out is None: out = torch.empty_like(input) - torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream()) + torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream()) return out def fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: - torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps) + torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps) def gemma_rmsnorm( @@ -33,14 +32,14 @@ def gemma_rmsnorm( ) -> torch.Tensor: if out is None: out = torch.empty_like(input) - 
torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream()) + torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream()) return out def gemma_fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: - torch.ops.sgl_kernels.gemma_fused_add_rmsnorm( + torch.ops.sgl_kernel.gemma_fused_add_rmsnorm( input, residual, weight, eps, get_cuda_stream() ) @@ -66,7 +65,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream()) return out @@ -81,7 +80,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream()) return out @@ -96,7 +95,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream()) return out @@ -141,7 +140,7 @@ def apply_rope_with_cos_sin_cache_inplace( raise ValueError("cos_sin_cache should be float32") positions = positions.int() - torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache( + torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache( q=query.view(query.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size), q_rope=query.view(query.shape[0], -1, head_size), diff --git a/sgl-kernel/src/sgl-kernel/ops/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py similarity index 81% rename from sgl-kernel/src/sgl-kernel/ops/gemm.py rename to sgl-kernel/python/sgl_kernel/gemm.py index 883894e96..e5936da56 100644 --- a/sgl-kernel/src/sgl-kernel/ops/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -1,12 +1,11 @@ from typing import List, Optional -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream +from sgl_kernel.utils import _get_cache_buf, get_cuda_stream def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return torch.ops.sgl_kernels.int8_scaled_mm( + return torch.ops.sgl_kernel.int8_scaled_mm( mat_a, mat_b, scales_a, @@ -17,7 +16,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype): - return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm( + return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm( mat_a, mat_b, scales_a, @@ -27,7 +26,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype): def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return torch.ops.sgl_kernels.fp8_scaled_mm( + return torch.ops.sgl_kernel.fp8_scaled_mm( mat_a, mat_b, scales_a, @@ -46,7 +45,7 @@ def _bmm_fp8_internal( B_scale: torch.Tensor, ) -> None: cublas_handle = torch.cuda.current_blas_handle() - torch.ops.sgl_kernels.bmm_fp8( + torch.ops.sgl_kernel.bmm_fp8( A, B, D, @@ -86,7 +85,7 @@ def sgl_per_token_group_quant_fp8( fp8_min: float, fp8_max: float, ) -> None: - torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8( + torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8( input, output_q, output_s, group_size, eps, fp8_min, fp8_max ) @@ -97,7 +96,7 @@ def sgl_per_tensor_quant_fp8( 
output_s: torch.Tensor, is_static: bool, ) -> None: - torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static) + torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static) def cublas_grouped_gemm( @@ -110,7 +109,7 @@ def cublas_grouped_gemm( len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0 ), "Inputs/weights/outputs should not be empty!" cublas_handle = torch.cuda.current_blas_handle() - torch.ops.sgl_kernels.cublas_grouped_gemm( + torch.ops.sgl_kernel.cublas_grouped_gemm( inputs, weights, outputs, @@ -125,4 +124,4 @@ def sgl_per_token_quant_fp8( output_q: torch.Tensor, output_s: torch.Tensor, ) -> None: - torch.ops.sgl_kernels.sgl_per_token_quant_fp8(input, output_q, output_s) + torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s) diff --git a/sgl-kernel/src/sgl-kernel/ops/moe.py b/sgl-kernel/python/sgl_kernel/moe.py similarity index 83% rename from sgl-kernel/src/sgl-kernel/ops/moe.py rename to sgl-kernel/python/sgl_kernel/moe.py index 208198272..ad20da036 100644 --- a/sgl-kernel/src/sgl-kernel/ops/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -1,4 +1,3 @@ -import sgl_kernel.ops._kernels import torch @@ -12,7 +11,7 @@ def moe_align_block_size( token_cnts_buffer, cumsum_buffer, ): - torch.ops.sgl_kernels.moe_align_block_size( + torch.ops.sgl_kernel.moe_align_block_size( topk_ids, num_experts, block_size, diff --git a/sgl-kernel/src/sgl-kernel/ops/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py similarity index 94% rename from sgl-kernel/src/sgl-kernel/ops/sampling.py rename to sgl-kernel/python/sgl_kernel/sampling.py index 1be42f8fd..7bf10bd4a 100644 --- a/sgl-kernel/src/sgl-kernel/ops/sampling.py +++ b/sgl-kernel/python/sgl_kernel/sampling.py @@ -1,8 +1,7 @@ from typing import Optional, Tuple, Union -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream +from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream def _top_k_renorm_probs_internal( @@ -13,7 +12,7 @@ def _top_k_renorm_probs_internal( probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None renorm_probs = torch.empty_like(probs) - torch.ops.sgl_kernels.top_k_renorm_probs_wrapper( + torch.ops.sgl_kernel.top_k_renorm_probs_wrapper( probs, renorm_probs, maybe_top_k_arr, @@ -41,7 +40,7 @@ def _top_p_renorm_probs_internal( probs = probs.float() maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None renorm_probs = torch.empty_like(probs) - torch.ops.sgl_kernels.top_p_renorm_probs( + torch.ops.sgl_kernel.top_p_renorm_probs( probs, renorm_probs, maybe_top_p_arr, @@ -76,7 +75,7 @@ def _top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device) - torch.ops.sgl_kernels.top_p_sampling_from_probs( + torch.ops.sgl_kernel.top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -122,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device) - torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs( + torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -180,7 +179,7 @@ def _min_p_sampling_from_probs_internal( maybe_min_p_arr.float() if maybe_min_p_arr is not None else None ) samples = 
torch.empty(probs.size(0), dtype=torch.int32, device=device) - torch.ops.sgl_kernels.min_p_sampling_from_probs( + torch.ops.sgl_kernel.min_p_sampling_from_probs( probs, uniform_samples, samples, diff --git a/sgl-kernel/src/sgl-kernel/ops/speculative.py b/sgl-kernel/python/sgl_kernel/speculative.py similarity index 88% rename from sgl-kernel/src/sgl-kernel/ops/speculative.py rename to sgl-kernel/python/sgl_kernel/speculative.py index f209f16a9..53acb1d95 100644 --- a/sgl-kernel/src/sgl-kernel/ops/speculative.py +++ b/sgl-kernel/python/sgl_kernel/speculative.py @@ -1,6 +1,5 @@ -import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops.utils import get_cuda_stream +from sgl_kernel.utils import get_cuda_stream def tree_speculative_sampling_target_only( @@ -16,7 +15,7 @@ def tree_speculative_sampling_target_only( draft_probs: torch.Tensor, deterministic: bool = True, ) -> None: - torch.ops.sgl_kernels.tree_speculative_sampling_target_only( + torch.ops.sgl_kernel.tree_speculative_sampling_target_only( predicts, accept_index, accept_token_num, @@ -45,7 +44,7 @@ def build_tree_kernel_efficient( depth: int, draft_token_num: int, ) -> None: - torch.ops.sgl_kernels.build_tree_kernel_efficient( + torch.ops.sgl_kernel.build_tree_kernel_efficient( parent_list, selected_index, verified_seq_len, @@ -71,7 +70,7 @@ def build_tree_kernel( depth: int, draft_token_num: int, ) -> None: - torch.ops.sgl_kernels.build_tree_kernel( + torch.ops.sgl_kernel.build_tree_kernel( parent_list, selected_index, verified_seq_len, diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/python/sgl_kernel/utils.py similarity index 100% rename from sgl-kernel/src/sgl-kernel/ops/utils.py rename to sgl-kernel/python/sgl_kernel/utils.py diff --git a/sgl-kernel/src/sgl-kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py similarity index 100% rename from sgl-kernel/src/sgl-kernel/version.py rename to sgl-kernel/python/sgl_kernel/version.py diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 545ff1bfc..72d710b3d 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -48,16 +48,16 @@ def _get_version(): return line.split("=")[1].strip().strip('"') -operator_namespace = "sgl_kernels" +operator_namespace = "sgl_kernel" cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" turbomind = root / "3rdparty" / "turbomind" include_dirs = [ + root / "include", + root / "csrc", cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", - root / "src" / "sgl-kernel" / "include", - root / "src" / "sgl-kernel" / "csrc", flashinfer.resolve() / "include", flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "csrc", @@ -96,21 +96,21 @@ nvcc_flags_fp8 = [ ] sources = [ - "src/sgl-kernel/torch_extension.cc", - "src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu", - "src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu", - "src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu", - "src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu", - "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu", - "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu", - "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu", - "src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu", - "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu", - 
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu", - "src/sgl-kernel/csrc/speculative/eagle_utils.cu", - "src/sgl-kernel/csrc/speculative/speculative_sampling.cu", + "csrc/allreduce/trt_reduce_internal.cu", + "csrc/allreduce/trt_reduce_kernel.cu", + "csrc/attention/lightning_attention_decode_kernel.cu", + "csrc/elementwise/fused_add_rms_norm_kernel.cu", + "csrc/gemm/cublas_grouped_gemm.cu", + "csrc/gemm/fp8_gemm_kernel.cu", + "csrc/gemm/fp8_blockwise_gemm_kernel.cu", + "csrc/gemm/int8_gemm_kernel.cu", + "csrc/gemm/per_token_group_quant_fp8.cu", + "csrc/gemm/per_token_quant_fp8.cu", + "csrc/gemm/per_tensor_quant_fp8.cu", + "csrc/moe/moe_align_kernel.cu", + "csrc/speculative/eagle_utils.cu", + "csrc/speculative/speculative_sampling.cu", + "csrc/torch_extension.cc", "3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/norm.cu", @@ -158,7 +158,7 @@ extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linu ext_modules = [ CUDAExtension( - name="sgl_kernel.ops._kernels", + name="sgl_kernel.common_ops", sources=sources, include_dirs=include_dirs, extra_compile_args={ @@ -174,8 +174,8 @@ ext_modules = [ setup( name="sgl-kernel", version=_get_version(), - packages=find_packages(), - package_dir={"": "src"}, + packages=find_packages(where="python"), + package_dir={"": "python"}, ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)}, options={"bdist_wheel": {"py_limited_api": "cp39"}}, diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 9185e4ae1..25484ae7a 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -13,12 +13,9 @@ # limitations under the License. # ============================================================================== -import multiprocessing -import os import sys from pathlib import Path -import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension @@ -35,16 +32,16 @@ def _get_version(): return line.split("=")[1].strip().strip('"') -operator_namespace = "sgl_kernels" +operator_namespace = "sgl_kernel" include_dirs = [ - root / "src" / "sgl-kernel" / "include", - root / "src" / "sgl-kernel" / "csrc", + root / "include", + root / "csrc", ] sources = [ - "src/sgl-kernel/torch_extension_rocm.cc", - "src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip", - "src/sgl-kernel/csrc/moe/moe_align_kernel.cu", + "csrc/allreduce/custom_all_reduce.hip", + "csrc/moe/moe_align_kernel.cu", + "csrc/torch_extension_rocm.cc", ] cxx_flags = ["-O3"] @@ -64,26 +61,27 @@ hipcc_flags = [ "-DENABLE_FP8", ] +ext_modules = [ + CUDAExtension( + name="sgl_kernel.common_ops", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "nvcc": hipcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + py_limited_api=True, + ), +] + setup( name="sgl-kernel", version=_get_version(), packages=find_packages(), - package_dir={"": "src"}, - ext_modules=[ - CUDAExtension( - name="sgl_kernel.ops._kernels", - sources=sources, - include_dirs=include_dirs, - extra_compile_args={ - "nvcc": hipcc_flags, - "cxx": cxx_flags, - }, - libraries=libraries, - extra_link_args=extra_link_args, - py_limited_api=True, - ), - ], + package_dir={"": "python"}, + ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)}, options={"bdist_wheel": {"py_limited_api": "cp39"}}, - install_requires=["torch"], ) diff --git 
a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py index 0387637ab..9bbc4e76f 100644 --- a/sgl-kernel/tests/test_trt_allreduce.py +++ b/sgl-kernel/tests/test_trt_allreduce.py @@ -7,7 +7,7 @@ import unittest from typing import Any, List, Optional import ray -import sgl_kernel.ops.allreduce as custom_ops +import sgl_kernel.allreduce as custom_ops import torch import torch.distributed as dist from torch.distributed import ProcessGroup
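
Note on the user-facing rename: the Python surface flattens from `sgl_kernel.ops.*` to `sgl_kernel.*`, the torch operator namespace moves from `torch.ops.sgl_kernels` to `torch.ops.sgl_kernel`, and the compiled extension is renamed from `sgl_kernel.ops._kernels` to `sgl_kernel.common_ops`. A minimal before/after sketch, assuming a wheel built from this branch and an available CUDA device (the shapes and dtypes below are illustrative, not taken from the patch):

```python
import torch
import sgl_kernel  # importing the package loads sgl_kernel.common_ops,
                   # which registers ops under torch.ops.sgl_kernel.*

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
w = torch.ones(128, device="cuda", dtype=torch.float16)

# Old import path: from sgl_kernel.ops.activation import rmsnorm
# New layout: the wrapper lives in sgl_kernel.elementwise and is
# re-exported from the package root by python/sgl_kernel/__init__.py.
out = sgl_kernel.rmsnorm(x, w)

# The raw op namespace is renamed as well (was torch.ops.sgl_kernels.rmsnorm).
assert hasattr(torch.ops.sgl_kernel, "rmsnorm")
```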
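The README's revised checklist maps directly onto the new tree. As a hypothetical sketch of step 5 under the new layout — `my_op` and its file are illustrative placeholders, not anything added by this patch — a wrapper follows the same pattern as the relocated modules:

```python
# python/sgl_kernel/my_category.py (hypothetical file and op name)
import torch

from sgl_kernel.utils import get_cuda_stream  # moved from sgl_kernel.ops.utils


def my_op(input: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # The op's schema would be declared in csrc/torch_extension.cc inside
    # TORCH_LIBRARY_EXPAND(sgl_kernel, m) and its C++ declaration added to
    # include/sgl_kernel_ops.h; the wrapper then dispatches through the
    # renamed torch.ops.sgl_kernel namespace.
    torch.ops.sgl_kernel.my_op(out, input, get_cuda_stream())
    return out
```

The last step would be re-exporting the wrapper from python/sgl_kernel/__init__.py, mirroring the imports shown in the diff above.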