diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 31360c0a0..0c29322a4 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -41,6 +41,7 @@ jobs: - name: Install run: | pip3 install torch==2.5.1 + pip3 install pytest pip3 uninstall sgl-kernel -y || true cd sgl-kernel pip3 install . diff --git a/.gitignore b/.gitignore index 91966c664..75e29fac3 100644 --- a/.gitignore +++ b/.gitignore @@ -225,3 +225,5 @@ compile_commands.json # VSCode .vscode + +1 diff --git a/.gitmodules b/.gitmodules index c584a21e8..ed7603bfd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "sgl-kernel/3rdparty/cccl"] path = sgl-kernel/3rdparty/cccl url = https://github.com/NVIDIA/cccl.git +[submodule "sgl-kernel/3rdparty/flashinfer"] + path = sgl-kernel/3rdparty/flashinfer + url = https://github.com/flashinfer-ai/flashinfer.git diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer new file mode 160000 index 000000000..a0e99a3a8 --- /dev/null +++ b/sgl-kernel/3rdparty/flashinfer @@ -0,0 +1 @@ +Subproject commit a0e99a3a820109763d9a757138a5cdf7bbcd1f85 diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt new file mode 100644 index 000000000..c930aa5dd --- /dev/null +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -0,0 +1,225 @@ +Notice for flashinfer-ai/flashinfer +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------------------- +Some of the code in this project are adapted from other open-source projects with different +licenses. This product also bundles some third-party components under other open source licenses. +This section summarizes those components and their licenses. +See licenses/ for text of these licenses. + +BSD 3-Clause License +-------------------- + +include/flashinfer/attention/hopper/epilogue.cuh +include/flashinfer/attention/hopper/mainloop.cuh +include/flashinfer/attention/hopper/kernel_traits.cuh +include/flashinfer/attention/hopper/named_barrier.cuh +include/flashinfer/attention/hopper/tile_scheduler.cuh +include/flashinfer/attention/hopper/utils.cuh + +BSD 3-Clause "New" License +-------------------------- + +3rdparty/cutlass +include/flashinfer/attention/hopper/block_sparse_gather.cuh diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 9f9867113..a8d9517bb 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,5 +1,6 @@ from pathlib import Path +import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension @@ -24,10 +25,13 @@ def update_wheel_platform_tag(): cutlass = root / "3rdparty" / "cutlass" +flashinfer = root / "3rdparty" / "flashinfer" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", root / "src" / "sgl-kernel" / "csrc", + flashinfer.resolve() / "include", + flashinfer.resolve() / "csrc", ] nvcc_flags = [ "-DNDEBUG", @@ -39,9 +43,21 @@ nvcc_flags = [ "-gencode=arch=compute_89,code=sm_89", "-gencode=arch=compute_90,code=sm_90", "-gencode=arch=compute_90a,code=sm_90a", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF2_OPERATORS__", + "-std=c++17", + "-use_fast_math", + "-DFLASHINFER_ENABLE_F16", + "-DFLASHINFER_ENABLE_BF16", ] +for flag in [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", +]: + try: + torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass cxx_flags = ["-O3"] libraries = ["c10", "torch", "torch_python", "cuda"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] @@ -56,6 +72,7 @@ ext_modules = [ "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", + "src/sgl-kernel/csrc/norm.cu", ], include_dirs=include_dirs, extra_compile_args={ diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 480bec71f..3352abeb5 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -6,6 +6,7 @@ from sgl_kernel.ops import ( int8_scaled_mm, moe_align_block_size, register_graph_buffers, + rmsnorm, rotary_embedding, sampling_scaling_penalties, ) @@ -20,4 +21,5 @@ __all__ = [ "get_graph_buffer_ipc_meta", "register_graph_buffers", "rotary_embedding", + "rmsnorm", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/norm.cu b/sgl-kernel/src/sgl-kernel/csrc/norm.cu new file mode 100644 index 000000000..ad102a50d --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/norm.cu @@ -0,0 +1,28 @@ +#include +#include + +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) { + CHECK_INPUT(input); + CHECK_INPUT(weight); + auto device = input.device(); + CHECK_EQ(weight.device(), device); + CHECK_DIM(2, input); // input: (batch_size, hidden_size) + CHECK_DIM(1, weight); // weight: (hidden_size) + CHECK_EQ(input.size(1), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + CHECK_EQ(output.size(0), batch_size); + CHECK_EQ(output.size(1), hidden_size); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { + cudaError_t status = norm::RMSNorm(static_cast(input.data_ptr()), static_cast(weight.data_ptr()), + static_cast(output.data_ptr()), batch_size, hidden_size, eps, stream); + TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status))); + return true; + }); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index f2ae95d7f..ed359bfbb 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -30,6 +30,9 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); +// rms norm +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // trt_reduce m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); @@ -45,4 +48,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); // rotary embedding m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)"); + // rms norm + m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)"); } diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index b8abd57d3..e9eadb759 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,3 +1,6 @@ +from typing import Optional + +import torch from sgl_kernel.ops._kernels import all_reduce as _all_reduce from sgl_kernel.ops._kernels import dispose as _dispose from sgl_kernel.ops._kernels import ( @@ -7,6 +10,7 @@ from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers +from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding from sgl_kernel.ops._kernels import ( sampling_scaling_penalties as _sampling_scaling_penalties, @@ -76,3 +80,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox): return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) + + +def rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if out is None: + out = torch.empty_like(input) + stream = torch.cuda.current_stream().cuda_stream + stream_int = int(stream) + _rmsnorm(out, input, weight, eps, stream_int) + return out diff --git a/sgl-kernel/tests/test_rmsnorm.py b/sgl-kernel/tests/test_rmsnorm.py new file mode 100644 index 000000000..dda225de9 --- /dev/null +++ b/sgl-kernel/tests/test_rmsnorm.py @@ -0,0 +1,31 @@ +import pytest +import torch +from sgl_kernel import rmsnorm + + +def llama_rms_norm(x, w, eps=1e-6): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * w.float() + x = x.to(orig_dtype) + return x + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_norm(batch_size, hidden_size, dtype, specify_out): + x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + w = torch.randn(hidden_size).to(0).to(dtype) + + y_ref = llama_rms_norm(x, w) + if specify_out: + y = torch.empty_like(x) + rmsnorm(x, w, out=y) + else: + y = rmsnorm(x, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)