@triton.jit
def _triton_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    rd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    mrope_section_w: tl.constexpr,
    is_interleaved: tl.constexpr,
):
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # Rotates the q/k heads of one token in place. Inputs are flattened
    # (num_tokens, num_heads * hd) tensors, and the cos/sin caches have shape
    # (3, num_tokens, rd // 2) — one slab per T/H/W section — instead of
    # (3, bsz, seq_len, head_dim). Interleaved rotary layout is supported.
    token = tl.program_id(0)

    # Advance the q/k base pointers to this token's flattened head block.
    q_base = q_ptr + token * (n_qh * hd)
    k_base = k_ptr + token * (n_kh * hd)

    # cos/sin are laid out as three consecutive (num_tokens, rd // 2) slabs:
    # temporal, height, width.
    half_rd = rd // 2
    t_cos = cos + token * half_rd
    h_cos = t_cos + num_tokens * half_rd
    w_cos = h_cos + num_tokens * half_rd
    t_sin = sin + token * half_rd
    h_sin = t_sin + num_tokens * half_rd
    w_sin = h_sin + num_tokens * half_rd

    # One lane per rotary channel in the half-width cache.
    lane = tl.arange(0, pad_hd // 2)
    if is_interleaved:
        # Interleaved layout: T/H/W channels repeat with period 3.
        h_mask = ((lane % 3) == 1) & (lane <= 3 * mrope_section_h)
        w_mask = ((lane % 3) == 2) & (lane <= 3 * mrope_section_w)
        t_mask = ~(h_mask | w_mask)
    else:
        # Contiguous [t | h | w] split of the rotary half.
        t_hi = mrope_section_t
        h_hi = t_hi + mrope_section_h
        t_mask = lane < mrope_section_t
        h_mask = (t_hi <= lane) & (lane < h_hi)
        w_mask = (h_hi <= lane) & (lane < half_rd)

    # Each lane belongs to exactly one section; the masked loads return 0
    # elsewhere, so summing the three reads assembles the per-lane row.
    cos_row = (
        tl.load(t_cos + lane, mask=t_mask, other=0)
        + tl.load(h_cos + lane, mask=h_mask, other=0)
        + tl.load(w_cos + lane, mask=w_mask, other=0)
    )
    sin_row = (
        tl.load(t_sin + lane, mask=t_mask, other=0)
        + tl.load(h_sin + lane, mask=h_mask, other=0)
        + tl.load(w_sin + lane, mask=w_mask, other=0)
    )

    # Element offsets of the left (first) rotary half for every padded head
    # row; padding rows/cols are masked off.
    q_rows = tl.arange(0, pad_n_qh)[:, None]
    k_rows = tl.arange(0, pad_n_kh)[:, None]
    cols = tl.arange(0, pad_hd // 2)[None, :]
    q_lo_offs = q_rows * hd + cols
    k_lo_offs = k_rows * hd + cols
    q_mask = (q_rows < n_qh) & (cols < rd // 2)
    k_mask = (k_rows < n_kh) & (cols < rd // 2)

    q_lo = tl.load(q_base + q_lo_offs, mask=q_mask, other=0).to(sin_row.dtype)
    k_lo = tl.load(k_base + k_lo_offs, mask=k_mask, other=0).to(sin_row.dtype)

    # The right half sits rd // 2 elements further along the head dim and
    # shares the same validity mask.
    q_hi_offs = q_lo_offs + (rd // 2)
    k_hi_offs = k_lo_offs + (rd // 2)
    q_hi = tl.load(q_base + q_hi_offs, mask=q_mask, other=0).to(sin_row.dtype)
    k_hi = tl.load(k_base + k_hi_offs, mask=k_mask, other=0).to(sin_row.dtype)

    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # cos_row/sin_row are half-width, so the same row multiplies both halves.
    tl.store(q_base + q_lo_offs, q_lo * cos_row - q_hi * sin_row, mask=q_mask)
    tl.store(q_base + q_hi_offs, q_hi * cos_row + q_lo * sin_row, mask=q_mask)
    tl.store(k_base + k_lo_offs, k_lo * cos_row - k_hi * sin_row, mask=k_mask)
    tl.store(k_base + k_hi_offs, k_hi * cos_row + k_lo * sin_row, mask=k_mask)
def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the mrope triton kernel, rotating q/k in place.

    Args:
        q: [num_tokens, num_heads * head_size] flattened query.
        k: [num_tokens, num_kv_heads * head_size] flattened key.
        cos: [3, num_tokens, rotary_dim // 2] cosine cache
            (T/H/W positions with multimodal inputs).
        sin: [3, num_tokens, rotary_dim // 2] sine cache
            (T/H/W positions with multimodal inputs).
        mrope_section: [t, h, w] channel split of the rotary half.
        head_size: per-head hidden size.
        rotary_dim: number of rotary channels per head (kernel strides the
            cos/sin slabs by rotary_dim // 2).
        mrope_interleaved: True when T/H/W channels are interleaved rather
            than laid out contiguously.

    Returns:
        The rotated (q, k). When the inputs are already contiguous these are
        the input tensors, modified in place; otherwise they are the
        contiguous copies that were actually rotated, so callers must always
        use the returned tensors.
    """
    n_row, n_q_head_head_dim = q.shape
    assert (
        n_q_head_head_dim % head_size == 0
    ), f"q shape {n_q_head_head_dim} must be divisible by head_size {head_size}"
    n_q_head = n_q_head_head_dim // head_size
    assert (
        k.shape[1] % head_size == 0
    ), f"k shape {k.shape[1]} must be divisible by head_size {head_size}"
    n_kv_head = k.shape[1] // head_size

    # Pad head counts and head dim to powers of two for tl.arange.
    pad_hd = triton.next_power_of_2(head_size)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    # Ensure tensors passed into the kernel are contiguous.
    # It will be no-op if they are already contiguous.
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    # One program instance per token.
    _triton_mrope_forward[(n_row,)](
        q,
        k,
        cos,
        sin,
        n_row,
        n_q_head,
        n_kv_head,
        head_size,
        rotary_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        mrope_section[2],
        mrope_interleaved,
    )
    return q, k
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply (m)rotary embedding to query/key.

    2D positions (3, num_tokens) take the triton mrope path; 1D positions
    fall back to the plain rotary computation.

    NOTE(review): fused_set_kv_buffer_arg is accepted but never used in this
    path — confirm callers do not rely on it here.
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    self._match_cos_sin_cache_dtype(query)
    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape

    if positions.ndim == 2:
        # Multimodal case: per-section (T/H/W) positions handled by triton.
        assert self.mrope_section
        q, k = triton_mrope(
            query,
            key,
            cos,
            sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )
        return q.reshape(query_shape), k.reshape(key_shape)

    def _rotate(x: torch.Tensor, out_shape: torch.Size) -> torch.Tensor:
        # Rotate the first rotary_dim channels; pass the rest through.
        x = x.view(num_tokens, -1, self.head_size)
        rotated = _apply_rotary_emb(
            x[..., : self.rotary_dim], cos, sin, self.is_neox_style
        )
        return torch.cat(
            (rotated, x[..., self.rotary_dim :]), dim=-1
        ).reshape(out_shape)

    return _rotate(query, query_shape), _rotate(key, key_shape)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model_config(model_name: str):
    """Load the HuggingFace config for *model_name* (remote code allowed)."""
    return AutoConfig.from_pretrained(model_name, trust_remote_code=True)


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Build multimodal positions plus flattened q/k tensors for one config."""
    # 2D positions (3, num_tokens): one row per T/H/W section.
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Flattened (num_tokens, num_heads * head_size) query/key tensors.
    query = torch.randn(
        num_tokens, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        num_tokens, num_kv_heads * head_size, dtype=dtype, device=device
    )

    return positions, query, key


def calculate_stats(times: list[float]) -> dict[str, float]:
    """Summarize a list of timings (seconds): mean/median/p99/min/max."""
    arr = np.array(times)
    return {
        "mean": np.mean(arr),
        "median": np.median(arr),
        "p99": np.percentile(arr, 99),
        "min": np.min(arr),
        "max": np.max(arr),
    }
def _timed_iterations(fn, positions, query, key, n_iter):
    """Run fn(positions, q, k) n_iter times on fresh clones; return per-call seconds.

    Uses time.perf_counter (monotonic, high resolution) rather than
    time.time, which is too coarse for sub-millisecond kernel timings.
    The clone happens outside the timed region so only the rope forward
    is measured.
    """
    times = []
    for _ in range(n_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        fn(positions, query_clone, key_clone)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start_time)
    return times


def benchmark_mrope(
    model_name: str,
    num_tokens: int,
    head_dim: int,
    tp_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 8192,
    rope_theta: float = 10000,
    is_neox_style: bool = True,
    rope_scaling: dict[str, Any] = None,  # NOTE(review): effectively Optional
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 0,
    warmup_iter: int = 10,
    benchmark_iter: int = 100,
):
    """Benchmark forward_native (torch) against forward (triton) for mrope.

    Args mirror the model geometry after tensor-parallel sharding; num_heads
    and num_kv_heads are per-rank counts.

    Returns:
        (torch_stats, triton_stats): dicts from calculate_stats with
        mean/median/p99/min/max seconds per call.
    """
    torch.manual_seed(seed)
    torch.set_default_device(device)
    # the parameters to compute the q k v size based on tp_size
    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
    ).to(device=device)

    print(80 * "=")
    print(
        f"Evaluating model: {model_name} "
        f"with tp_size: {tp_size} "
        f"and num_tokens: {num_tokens}, "
        f"dtype: {dtype}"
    )

    # create q k v input tensors
    # create rotary pos emb input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Warm up both implementations (triggers triton compilation).
    for _ in range(warmup_iter):
        mrope_helper_class.forward_native(
            positions,
            query.clone(),
            key.clone(),
        )
        mrope_helper_class.forward(
            positions,
            query.clone(),
            key.clone(),
        )
    torch.cuda.synchronize()

    # Time reference implementation, then the triton kernel.
    torch_times = _timed_iterations(
        mrope_helper_class.forward_native, positions, query, key, benchmark_iter
    )
    triton_times = _timed_iterations(
        mrope_helper_class.forward, positions, query, key, benchmark_iter
    )

    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)
    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")

    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )

    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )

    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    return torch_stats, triton_stats
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()
    print(args)

    # No model given -> sweep the Qwen2-VL / Qwen2.5-VL family defaults.
    if args.model_name:
        model_tp_dict = {args.model_name: [args.tp_size]}
    else:
        model_tp_dict = {
            "Qwen/Qwen2-VL-2B-Instruct": [1],
            "Qwen/Qwen2-VL-7B-Instruct": [1],
            "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
            "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
            "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
            "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
        }

    num_tokens_list = (
        [2**i for i in range(0, 18)] if args.num_tokens is None else args.num_tokens
    )

    for model_name, tp_list in model_tp_dict.items():
        for tp_size in tp_list:
            # Derive per-rank head geometry from the HF config.
            config = get_model_config(model_name)
            total_num_kv_heads = config.num_key_value_heads
            total_num_heads = config.num_attention_heads
            num_heads = total_num_heads // tp_size
            num_kv_heads = max(1, total_num_kv_heads // tp_size)
            head_dim = config.hidden_size // total_num_heads

            for num_tokens in num_tokens_list:
                benchmark_mrope(
                    model_name=model_name,
                    num_tokens=num_tokens,
                    head_dim=head_dim,
                    tp_size=tp_size,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    max_position=config.max_position_embeddings,
                    rope_theta=config.rope_theta,
                    is_neox_style=True,
                    rope_scaling=config.rope_scaling,
                    dtype=getattr(torch, args.dtype),
                    seed=args.seed,
                    warmup_iter=args.warmup_iter,
                    benchmark_iter=args.benchmark_iter,
                )
def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Build deterministic positions and flattened q/k for one configuration."""
    # Fixed seed so native and triton paths see identical data.
    torch.manual_seed(42)
    # 2D positions (3, num_tokens): one row per T/H/W section.
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )
    query = torch.randn(
        num_tokens, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        num_tokens, num_kv_heads * head_size, dtype=dtype, device=device
    )
    return positions, query, key


class MRoPETestInfo(NamedTuple):
    # Per-model test configuration, with tolerance overrides per model.
    model_name: str
    atol: float = 1e-2
    rtol: float = 1.6e-2
    marks: list[pytest.MarkDecorator] = []


TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version

MODELS_TO_TEST = [
    MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
    MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
    MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
]

num_tokens_list = [11, 8192]


@pytest.mark.skipif(not _is_cuda, reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
    "model_info, model_name",
    [
        pytest.param(test_config, test_config.model_name, marks=test_config.marks)
        for test_config in MODELS_TO_TEST
    ],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(
    model_name: str,
    model_info: MRoPETestInfo,
    tp_size: int,
    dtype: torch.dtype,
    num_tokens: int,
):
    """Triton forward must match forward_native on real model geometries."""
    config = AutoConfig.from_pretrained(model_name)
    config = config.get_text_config()

    # Per-rank head geometry, mirroring tensor-parallel sharding.
    total_num_kv_heads = config.num_key_value_heads
    total_num_heads = config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    if hasattr(config, "head_dim"):
        head_dim = config.head_dim
    else:
        head_dim = config.hidden_size // total_num_heads

    max_position = config.max_position_embeddings
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    rotary_dim = int(head_dim * partial_rotary_factor)

    rope = get_rope(
        head_size=head_dim,
        rotary_dim=rotary_dim,
        max_position=max_position,
        base=config.rope_theta,
        is_neox_style=True,
        rope_scaling=config.rope_scaling,
        dtype=dtype,
    ).to(device=device)

    # Same deterministic inputs for both implementations.
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    q_ref, k_ref = rope.forward_native(
        positions,
        query.clone(),
        key.clone(),
    )
    q_tri, k_tri = rope.forward(
        positions,
        query.clone(),
        key.clone(),
    )

    torch.testing.assert_close(
        q_ref, q_tri, atol=model_info.atol, rtol=model_info.rtol
    )
    torch.testing.assert_close(
        k_ref, k_tri, atol=model_info.atol, rtol=model_info.rtol
    )