diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index 6e6f38d35..4f1f08989 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 6e6f38d3534994c34b2c6b09b5b45c8a7b92ffd2 +Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5 diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 20cccb113..6745d2e80 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -94,6 +94,7 @@ sources = [ "3rdparty/flashinfer/csrc/norm.cu", "3rdparty/flashinfer/csrc/sampling.cu", "3rdparty/flashinfer/csrc/renorm.cu", + "3rdparty/flashinfer/csrc/rope.cu", ] enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index df141dee1..e82eece48 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,4 +1,5 @@ from sgl_kernel.ops import ( + apply_rope_with_cos_sin_cache_inplace, bmm_fp8, custom_dispose, custom_reduce, @@ -25,6 +26,7 @@ from sgl_kernel.ops import ( ) __all__ = [ + "apply_rope_with_cos_sin_cache_inplace", "bmm_fp8", "custom_dispose", "custom_reduce", diff --git a/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu index 1dd4c4c52..d02554fb1 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu @@ -98,7 +98,7 @@ void rotary_embedding(torch::Tensor& positions, // [batch_size, seq_len] or [nu int64_t query_stride = query.stride(-2); int64_t key_stride = key.stride(-2); - dim3 grid(num_tokens); + dim3 grid(num_tokens); // each block is responsible for one token dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h 
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is inplace applied to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    head_size : int
        Size of a single attention head. ``query`` and ``key`` are viewed as
        ``(nnz, -1, head_size)`` before being handed to the kernel.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
        Must be ``torch.float32``.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Raises
    ------
    ValueError
        If ``cos_sin_cache`` is not of dtype ``torch.float32``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    with query.device as device:
        # Convert the caller's positions to int32 and pass the converted
        # tensor on. (Presumably the kernel requires int32 indices -- the
        # original code attempted this cast but read an undefined local.)
        pos_ids = positions.int()

        # The same views serve as both input and output so the rotation is
        # written back into the caller's query/key storage.
        query_view = query.view(query.shape[0], -1, head_size)
        key_view = key.view(key.shape[0], -1, head_size)

        torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
            q=query_view,
            k=key_view,
            q_rope=query_view,
            k_rope=key_view,
            cos_sin_cache=cos_sin_cache,
            pos_ids=pos_ids,
            interleave=(not is_neox),
            cuda_stream=_get_cuda_stream(device),
        )
# vLLM torch native
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Rotate ``x`` by the given cos/sin values (PyTorch reference).

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.

    Returns:
        A tensor with the same shape as ``x`` holding the rotated values.
    """
    # Insert a heads axis so cos/sin broadcast across every head.
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # Neox: the rotation pairs are (first half, second half).
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J: the rotation pairs are (even index, odd index).
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    # Re-interleave the rotated even/odd components.
    return torch.stack((o1, o2), dim=-1).flatten(-2)


class RotaryEmbedding(torch.nn.Module):
    """PyTorch-native rotary embedding used as the reference implementation.

    Builds a float32 ``cos_sin_cache`` buffer at construction time whose rows
    are ``[cos | sin]`` for each position, and applies it in
    ``forward_native``.
    """

    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Store the configuration and register the cos/sin cache buffer.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of leading head dimensions that are rotated;
                the remaining ``head_size - rotary_dim`` pass through.
            max_position_embeddings: Number of positions in the cache.
            base: Base of the inverse-frequency geometric progression.
            is_neox_style: Neox-style (half split) vs GPT-J-style
                (interleaved) rotation.
            dtype: Output dtype of ``forward_native``.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: torch.Tensor
        # Non-persistent: the cache is recomputed, not stored in state_dict.
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return ``1 / base ** (2i / rotary_dim)`` for each rotation pair."""
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Returns:
            A float32 tensor with one row per position; the first half of the
            last dimension is cosine and the second half is sine.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        # Outer product of positions and inverse frequencies.
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward().

        Args:
            positions: Position indices; flattened before use.
            query: Query tensor, viewed as ``(num_tokens, -1, head_size)``.
            key: Key tensor, viewed as ``(num_tokens, -1, head_size)``.
            offsets: Optional tensor added to ``positions`` before lookup.

        Returns:
            ``(query, key)`` as new tensors in ``self.dtype`` with the
            original input shapes.
        """
        if offsets is not None:
            positions = positions + offsets

        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)

        # Modification: float32 is required for the rotary embedding to work correctly
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        cos, sin = cos_sin.chunk(2, dim=-1)

        # Only the first rotary_dim entries of each head are rotated; the
        # rest are concatenated back unchanged.
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)

        # Modification: convert to the correct dtype
        query = query.to(self.dtype)
        key = key.to(self.dtype)
        return query, key


class FlashInferRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding that dispatches to the sgl_kernel RoPE op."""

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE in place via ``apply_rope_with_cos_sin_cache_inplace``.

        Args:
            positions: Position indices.
            query: Query tensor; modified in place.
            key: Key tensor; modified in place.
            offsets: Optional tensor added to ``positions`` before lookup,
                matching ``forward_native``.

        Returns:
            The same ``query`` and ``key`` tensor objects that were passed in.
        """
        # Keep parity with forward_native: previously offsets was accepted
        # but silently ignored on this path.
        if offsets is not None:
            positions = positions + offsets

        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=self.cos_sin_cache,
            is_neox=self.is_neox_style,
        )

        return query, key
@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
    [
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
    ],
)
def test_correctness(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
):
    """Check the kernel-backed ``forward_cuda`` against ``forward_native``.

    Both modules are built from the same configuration and fed clones of the
    same random query/key tensors; the outputs must agree within
    ``atol=rtol=1e-2``.
    """
    rope_ref = RotaryEmbedding(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    ).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    ).to(device)

    # Positions 0..seq_len-1 repeated once per sequence in the batch.
    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
    query = torch.randn(
        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )

    # Separate clones: forward_cuda rotates its inputs in place.
    query_ref, key_ref = query.clone(), key.clone()
    query_flashinfer, key_flashinfer = query.clone(), key.clone()

    query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        pos_ids, query_flashinfer, key_flashinfer
    )

    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)