diff --git a/.gitignore b/.gitignore
index 73fd52992..91966c664 100644
--- a/.gitignore
+++ b/.gitignore
@@ -222,3 +222,6 @@ work_dirs/
 compile_commands.json
 
 *.iml
+
+# VSCode
+.vscode
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index b0554bd8f..ab9d68b44 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post14"
+version = "0.0.2.post15"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 33e4abe1b..25319af7a 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -53,6 +53,7 @@ ext_modules = [
             "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
             "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
             "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+            "src/sgl-kernel/csrc/rotary_embedding.cu",
         ],
         include_dirs=include_dirs,
         extra_compile_args={
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 0c744982d..480bec71f 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -6,6 +6,7 @@ from sgl_kernel.ops import (
     int8_scaled_mm,
     moe_align_block_size,
     register_graph_buffers,
+    rotary_embedding,
     sampling_scaling_penalties,
 )
 
@@ -18,4 +19,5 @@ __all__ = [
     "sampling_scaling_penalties",
     "get_graph_buffer_ipc_meta",
     "register_graph_buffers",
+    "rotary_embedding",
 ]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
new file mode 100644
index 000000000..1dd4c4c52
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
@@ -0,0 +1,119 @@
+// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
+                                                    const scalar_t* __restrict__ sin_ptr, int rot_offset,
+                                                    int embed_dim) {
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = __ldg(cos_ptr + x_index);
+    sin = __ldg(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = __ldg(cos_ptr + x_index / 2);
+    sin = __ldg(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                                                             // head_size] or [num_tokens, num_heads,
+                                                                             // head_size]
+                                              scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                                                             // head_size] or [num_tokens, num_kv_heads,
+                                                                             // head_size]
+                                              const scalar_t* cache_ptr, const int head_size, const int num_heads,
+                                              const int num_kv_heads, const int rot_dim, const int token_idx,
+                                              const int64_t query_stride, const int64_t key_stride) {
+  const int embed_dim = rot_dim / 2;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+  }
+}
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                                                                // [num_tokens]
+                                        scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                                                       // head_size] or [num_tokens, num_heads,
+                                                                       // head_size]
+                                        scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                                                     // head_size] or [num_tokens, num_kv_heads,
+                                                                     // head_size]
+                                        const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                                                     // 2]
+                                        const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+                                        const int num_heads, const int num_kv_heads, const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+                                            token_idx, query_stride, key_stride);
+}
+
+void rotary_embedding(torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
+                      torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or
+                                                 // [num_tokens, num_heads * head_size]
+                      torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or
+                                                 // [num_tokens, num_kv_heads * head_size]
+                      int64_t head_size,
+                      torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+                      bool is_neox) {
+  int64_t num_tokens = query.numel() / query.size(-1);
+  int rot_dim = cos_sin_cache.size(1);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
+        if (is_neox) {
+          rotary_embedding_kernel<scalar_t, true>
+              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
+                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
+        } else {
+          rotary_embedding_kernel<scalar_t, false>
+              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
+                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
+        }
+      });
+}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index 99d0326cf..f2ae95d7f 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -26,6 +26,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                              const c10::optional<torch::Tensor>& bias);
 
+// rotary embedding
+void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
+                      torch::Tensor& cos_sin_cache, bool is_neox);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -39,4 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
   // int8_scaled_mm
   m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
+  // rotary embedding
+  m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
 }
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 6b35f78a4..b8abd57d3 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -7,6 +7,7 @@ from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
 from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
+from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
 from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )
@@ -71,3 +72,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
         out_dtype,
         bias,
     )
+
+
+def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
+    return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py
new file mode 100644
index 000000000..1bbe8f1bf
--- /dev/null
+++ b/sgl-kernel/tests/test_rotary_embedding.py
@@ -0,0 +1,118 @@
+from typing import Optional, Tuple
+
+import torch
+from vllm.model_executor.layers.rotary_embedding import (
+    RotaryEmbedding as VLLMRotaryEmbedding,
+)
+
+
+class SGLRotaryEmbedding(VLLMRotaryEmbedding):
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from sgl_kernel import rotary_embedding
+
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+
+        rotary_embedding(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )
+        return query, key
+
+
+# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native
+
+
+def test_rotary_embedding():
+    # Test case 1: FP32
+    def run_test(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        dtype,
+        batch_size,
+        seq_len,
+        num_heads,
+        test_name,
+    ):
+        print(f"\nRunning {test_name}...")
+        # Initialize both implementations
+        sgl_rope = SGLRotaryEmbedding(
+            head_size, rotary_dim, max_position, base, is_neox_style, dtype
+        ).to("cuda")
+        vllm_rope = VLLMRotaryEmbedding(
+            head_size, rotary_dim, max_position, base, is_neox_style, dtype
+        ).to("cuda")
+
+        # Regular forward pass
+        positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
+        query = torch.randn(
+            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
+        )
+        key = torch.randn(
+            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
+        )
+
+        # Make copies for both implementations
+        query_sgl = query.clone()
+        key_sgl = key.clone()
+        query_vllm = query.clone()
+        key_vllm = key.clone()
+
+        # Run both implementations
+        query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
+            positions, query_sgl, key_sgl
+        )
+        query_vllm_out, key_vllm_out = vllm_rope.forward_native(
+            positions, query_vllm, key_vllm
+        )
+
+        # Compare outputs
+        torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)
+
+        print(f"{test_name} passed!")
+
+    # Test Case 1: FP32 with larger dimensions
+    run_test(
+        head_size=128,
+        rotary_dim=64,
+        max_position=4096,
+        base=10000,
+        is_neox_style=True,
+        dtype=torch.float32,
+        batch_size=4,
+        seq_len=32,
+        num_heads=8,
+        test_name="FP32 Test",
+    )
+
+    # Test Case 2: BF16 with smaller dimensions
+    run_test(
+        head_size=64,
+        rotary_dim=32,
+        max_position=2048,
+        base=8000,
+        is_neox_style=True,
+        dtype=torch.bfloat16,
+        batch_size=2,
+        seq_len=16,
+        num_heads=4,
+        test_name="BF16 Test",
+    )
+
+
+if __name__ == "__main__":
+    test_rotary_embedding()
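
For reference, below is a minimal usage sketch of the new op (not part of the diff), distilled from tests/test_rotary_embedding.py. It assumes sgl-kernel with this change is installed and a CUDA device is available; the shapes and the cache-construction helper are illustrative. The cache is built in the [max_position, rot_dim] layout that vLLM's RotaryEmbedding produces (cosines in the first rot_dim // 2 columns, sines in the rest), which is what the kernel indexes into.

```python
# Usage sketch (illustrative, not part of this PR). Assumes sgl-kernel with
# this change is installed and a CUDA device is available.
import torch

from sgl_kernel import rotary_embedding

num_tokens, num_heads, head_size = 8, 4, 64
rot_dim, max_position, base = 64, 2048, 10000

# Build cos_sin_cache in the layout the kernel expects: [max_position, rot_dim],
# cosines in the first rot_dim // 2 columns, sines in the remaining columns
# (the same layout vLLM's RotaryEmbedding builds). Must match query's dtype.
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
t = torch.arange(max_position, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).to("cuda")

positions = torch.arange(num_tokens, device="cuda")  # int64 position ids
query = torch.randn(num_tokens, num_heads * head_size, device="cuda")
key = torch.randn(num_tokens, num_heads * head_size, device="cuda")

# Rotates query and key in place; is_neox=True pairs element i with element
# i + rot_dim // 2 (GPT-NeoX style) rather than adjacent pairs (GPT-J style).
rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)
```

The outputs can be cross-checked against vLLM's `RotaryEmbedding.forward_native`, which is exactly what the bundled test does for FP32 and BF16.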