From eb06dbcbf8cb652a3f6c3b0392366b1cfed3515d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 9 Mar 2025 18:38:15 -0700 Subject: [PATCH] Move rope and bmm into sgl-kernel (#4241) --- sgl-kernel/csrc/elementwise/rope.cu | 89 +++++++++++++++++++++++++++++ sgl-kernel/csrc/gemm/bmm_fp8.cu | 76 ++++++++++++++++++++++++ sgl-kernel/include/sgl_kernel_ops.h | 18 +++--- sgl-kernel/pyproject.toml | 7 ++- sgl-kernel/setup.py | 6 +- 5 files changed, 183 insertions(+), 13 deletions(-) create mode 100644 sgl-kernel/csrc/elementwise/rope.cu create mode 100644 sgl-kernel/csrc/gemm/bmm_fp8.cu diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu new file mode 100644 index 000000000..49565f6f0 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +void apply_rope_pos_ids_cos_sin_cache( + at::Tensor q, + at::Tensor k, + at::Tensor q_rope, + at::Tensor k_rope, + at::Tensor cos_sin_cache, + at::Tensor pos_ids, + bool interleave, + int64_t cuda_stream) { + CHECK_LAST_DIM_CONTIGUOUS(q); + CHECK_LAST_DIM_CONTIGUOUS(k); + CHECK_INPUT(cos_sin_cache); + CHECK_INPUT(pos_ids); + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_EQ(cos_sin_cache.device(), device); + CHECK_EQ(pos_ids.device(), device); + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) + // cos_sin_cache: (max_seq_len, R) + // First half of R is cos, second half is sin + CHECK_DIM(2, cos_sin_cache); + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int rotary_dim = cos_sin_cache.size(1); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int nnz = q.size(0); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( + static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), + static_cast(k_rope.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), + nnz, + num_qo_heads, + num_kv_heads, + rotary_dim, + head_dim, + q_stride_n, + q_stride_h, + k_stride_n, + k_stride_h, + q_rope_stride_n, + q_rope_stride_h, + k_rope_stride_n, + k_rope_stride_h, + interleave, + stream); + TORCH_CHECK( + status == cudaSuccess, + "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status))); + return true; + }); +} diff --git a/sgl-kernel/csrc/gemm/bmm_fp8.cu b/sgl-kernel/csrc/gemm/bmm_fp8.cu new file mode 100644 index 000000000..4a82b4b27 --- /dev/null +++ b/sgl-kernel/csrc/gemm/bmm_fp8.cu @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include "pytorch_extension_utils.h" + +void bmm_fp8( + at::Tensor A, + at::Tensor B, + at::Tensor D, + at::Tensor A_scale, + at::Tensor B_scale, + at::Tensor workspace_buffer, + int64_t cublas_handle, + int64_t cuda_stream) { + TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor"); + TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor"); + TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor"); + TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A"); + TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B"); + TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D"); + TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match"); + TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes"); + TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2), "Result tensor has incorrect shape"); + + // PyTorch is row major by default. cuBLASLt is column major by default. + // We need row major D as expected. + // A ^ T * B = D, so D ^ T = B ^ T * A + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] { + auto batch_size = A.size(0); + auto m = A.size(1); + auto k = A.size(2); + auto n = B.size(2); + + auto lt_handle = reinterpret_cast(cublas_handle); + auto stream = reinterpret_cast(cuda_stream); + + auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( + workspace_buffer.data_ptr(), + workspace_buffer.numel(), + static_cast(B.data_ptr()), + static_cast(A.data_ptr()), + static_cast(D.data_ptr()), + batch_size, + n, + m, + k, + static_cast(B_scale.data_ptr()), + static_cast(A_scale.data_ptr()), + lt_handle, + stream); + TORCH_CHECK( + status == CUBLAS_STATUS_SUCCESS, "bmm_fp8_internal_cublaslt failed: ", cublasGetStatusString(status)); + return true; + }); + }); + }); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 82412b6e0..5f0ae34eb 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -140,6 +140,15 @@ void cublas_grouped_gemm( const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream); +void bmm_fp8( + at::Tensor A, + at::Tensor B, + at::Tensor D, + at::Tensor A_scale, + at::Tensor B_scale, + at::Tensor workspace_buffer, + int64_t cublas_handle, + int64_t cuda_stream); /* * From csrc/moe @@ -198,15 +207,6 @@ void build_tree_kernel( /* * From FlashInfer */ -void bmm_fp8( - at::Tensor A, - at::Tensor B, - at::Tensor D, - at::Tensor A_scale, - at::Tensor B_scale, - at::Tensor workspace_buffer, - int64_t cublas_handle, - int64_t cuda_stream); void min_p_sampling_from_probs( at::Tensor probs, at::Tensor uniform_samples, diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index cdc7f936c..712ed36cf 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,5 +1,10 @@ [build-system] -requires = ["setuptools>=61.0", "wheel", "torch"] +requires = [ + "setuptools>=61.0", + "scikit-build-core>=0.10", + "torch==2.5.1", + "wheel", +] build-backend = "setuptools.build_meta" [project] diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index d76a2668a..0c273f97d 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -97,6 +97,8 @@ sources = [ "csrc/allreduce/trt_reduce_kernel.cu", "csrc/attention/lightning_attention_decode_kernel.cu", "csrc/elementwise/fused_add_rms_norm_kernel.cu", + "csrc/elementwise/rope.cu", + "csrc/gemm/bmm_fp8.cu", "csrc/gemm/cublas_grouped_gemm.cu", "csrc/gemm/fp8_gemm_kernel.cu", "csrc/gemm/fp8_blockwise_gemm_kernel.cu", @@ -109,11 +111,9 @@ sources = [ "csrc/speculative/speculative_sampling.cu", "csrc/torch_extension.cc", "3rdparty/flashinfer/csrc/activation.cu", - "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/norm.cu", - "3rdparty/flashinfer/csrc/sampling.cu", "3rdparty/flashinfer/csrc/renorm.cu", - "3rdparty/flashinfer/csrc/rope.cu", + "3rdparty/flashinfer/csrc/sampling.cu", ] enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"