Support mrope triton kernel and add unit test (#11722)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -7,6 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -1033,6 +1035,188 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.T
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _triton_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    rd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    mrope_section_w: tl.constexpr,
    is_interleaved: tl.constexpr,
):
    # Apply multimodal rotary position embedding (mrope) in place to the query
    # and key rows of a single token. One program instance handles one token.
    #
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # This version works on flattened inputs (num_tokens, n_heads * hd)
    # and on cos / sin caches laid out as (3, num_tokens, rd // 2)
    # instead of (3, bsz, seq_len, head_dim); it also supports interleaved rotary.
    #
    # Arguments:
    #   q_ptr / k_ptr  - contiguous query / key buffers, updated in place
    #   cos / sin      - contiguous caches, T / H / W sections stacked on dim 0
    #   num_tokens     - number of tokens (stride between the T / H / W sections)
    #   n_qh / n_kh    - number of query / key heads
    #   hd / rd        - head size / rotary dimension (rd <= hd)
    #   pad_*          - the corresponding sizes as used for tl.arange
    #   mrope_section_*- number of rotary frequencies taken from T, H and W
    #   is_interleaved - choose interleaved instead of blocked T/H/W layout

    # Widen the program id to int64: the per-token offsets pid * (n_heads * hd)
    # would otherwise be computed in 32-bit arithmetic and could wrap around
    # when num_tokens * n_heads * hd exceeds the int32 range.
    pid = tl.program_id(0).to(tl.int64)
    # locate start address
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    # Note: cos and sin have shape (3, num_tokens, rd // 2)

    # Row stride of the cos / sin caches is half of the rotary dimension.
    half_rd = rd // 2
    t_cos = cos + pid * half_rd
    h_cos = t_cos + num_tokens * half_rd
    w_cos = h_cos + num_tokens * half_rd
    t_sin = sin + pid * half_rd
    h_sin = t_sin + num_tokens * half_rd
    w_sin = h_sin + num_tokens * half_rd

    # Offsets inside one (padded) half head.
    cos_offsets = tl.arange(0, pad_hd // 2)
    if is_interleaved:
        h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
        w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
        # Everything not claimed by H or W belongs to T. Restrict it to the
        # valid row width so padded lanes never read past this token's row;
        # those lanes are masked out of every q / k load and store anyway.
        t_mask = (~(h_mask | w_mask)) & (cos_offsets < half_rd)
    else:
        t_end = mrope_section_t
        h_end = t_end + mrope_section_h
        t_mask = cos_offsets < mrope_section_t
        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)

    # The three masks are disjoint, so summing merges the sections into one row.
    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )

    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (rd // 2)
    second_half_k_offsets = first_half_k_offsets + (rd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask

    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are half-size,
    # we use the same cos_row and sin_row for both halves
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply multimodal rotary embedding to ``q`` and ``k`` with the Triton kernel.

    The kernel writes its result into the (contiguous) buffers it is given, so
    inputs that are already contiguous are modified in place; non-contiguous
    inputs are copied first and the copies are returned.

    Args:
        q: [num_tokens, num_heads * head_size]
        k: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, rotary_dim // 2]
            (T/H/W positions with multimodal inputs)
        sin: [3, num_tokens, rotary_dim // 2]
            (T/H/W positions with multimodal inputs)
        mrope_section: [t, h, w], number of rotary frequencies per section
        head_size: size of one attention head
        rotary_dim: number of leading channels of each head that are rotated
        mrope_interleaved: use the interleaved instead of the blocked
            T/H/W frequency layout

    Returns:
        The rotated ``(q, k)`` pair, with the same shapes as the inputs.
    """
    n_row, n_q_head_head_dim = q.shape
    assert (
        n_q_head_head_dim % head_size == 0
    ), f"q shape {n_q_head_head_dim} must be divisible by head_size {head_size}"
    n_q_head = n_q_head_head_dim // head_size
    assert (
        k.shape[1] % head_size == 0
    ), f"k shape {k.shape[1]} must be divisible by head_size {head_size}"
    n_kv_head = k.shape[1] // head_size

    # The kernel indexes k, cos and sin with the token index and row strides
    # derived from q / rotary_dim, so mismatching shapes would silently read
    # or write the wrong rows.
    assert k.shape[0] == n_row, "q and k must have the same number of tokens"
    assert len(mrope_section) == 3, "mrope_section must be [t, h, w]"
    assert cos.shape == sin.shape, "cos and sin must have the same shape"
    assert tuple(cos.shape) == (
        3,
        n_row,
        rotary_dim // 2,
    ), "cos/sin must have shape (3, num_tokens, rotary_dim // 2)"

    # Nothing to rotate; also avoids launching a kernel with an empty grid.
    if n_row == 0:
        return q, k

    pad_hd = triton.next_power_of_2(head_size)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    # ensure tensors passed into the kernel are contiguous.
    # It will be no-op if they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    # One program instance per token.
    _triton_mrope_forward[(n_row,)](
        q,
        k,
        cos,
        sin,
        n_row,
        n_q_head,
        n_kv_head,
        head_size,
        rotary_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        mrope_section[2],
        mrope_interleaved,
    )
    return q, k
|
||||||
|
|
||||||
|
|
||||||
class MRotaryEmbedding(RotaryEmbedding):
|
class MRotaryEmbedding(RotaryEmbedding):
|
||||||
"""Rotary Embedding with Multimodal Sections."""
|
"""Rotary Embedding with Multimodal Sections."""
|
||||||
|
|
||||||
@@ -1086,8 +1270,17 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
|
||||||
|
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||||
|
# is expensive, so avoid calling it if possible
|
||||||
|
if (
|
||||||
|
self.cos_sin_cache.device != query.device
|
||||||
|
or self.cos_sin_cache.dtype != query.dtype
|
||||||
|
):
|
||||||
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
def forward(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -1141,6 +1334,51 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding to ``query`` and ``key``.

    Args:
        positions: 1-D positions of shape ``[num_tokens]`` (plain rotary
            embedding) or 2-D positions of shape ``[3, num_tokens]``
            (T/H/W positions, handled by the Triton mrope kernel).
        query: query tensor; the result has the same shape.
        key: key tensor; the result has the same shape.
        fused_set_kv_buffer_arg: must be ``None``. This method never writes
            the KV buffer, so a request to do so is rejected instead of
            being silently dropped.

    Returns:
        The rotated ``(query, key)`` pair.
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None
    assert (
        fused_set_kv_buffer_arg is None
    ), "save kv cache is not supported for MRotaryEmbedding."

    self._match_cos_sin_cache_dtype(query)
    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape

    if positions.ndim == 2:
        # Multimodal positions: rotate q and k with the fused Triton kernel.
        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            cos,
            sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return q.reshape(query_shape), k.reshape(key_shape)

    def _rotate(x: torch.Tensor, out_shape: torch.Size) -> torch.Tensor:
        # Rotate the first rotary_dim channels of every head and pass the
        # remaining channels through unchanged.
        x = x.view(num_tokens, -1, self.head_size)
        x_rot = x[..., : self.rotary_dim]
        x_pass = x[..., self.rotary_dim :]
        x_rot = _apply_rotary_emb(x_rot, cos, sin, self.is_neox_style)
        return torch.cat((x_rot, x_pass), dim=-1).reshape(out_shape)

    query = _rotate(query, query_shape)
    key = _rotate(key, key_shape)
    return query, key
|
||||||
|
|
||||||
# Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
|
# Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_rope_index(
|
def get_rope_index(
|
||||||
|
|||||||
250
sgl-kernel/benchmark/bench_mrope.py
Normal file
250
sgl-kernel/benchmark/bench_mrope.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
# Adapted from vLLM benchmark_mrope.py
|
||||||
|
|
||||||
|
# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
|
||||||
|
# It generates test data, runs benchmarks, and prints timing statistics to stdout.
# NOTE: this script currently does not write the CSV file described below.
|
||||||
|
#
|
||||||
|
# The CSV file (named with current date/time) contains these columns:
|
||||||
|
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
|
||||||
|
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
|
||||||
|
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
|
||||||
|
# speedup
|
||||||
|
#
|
||||||
|
# == Usage Examples ==
|
||||||
|
#
|
||||||
|
# Single model benchmark:
|
||||||
|
# python3 bench_mrope.py --model-name Qwen/Qwen2.5-VL-7B-Instruct --tp-size 8 \
|
||||||
|
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
|
# Benchmark device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model_name: str):
    """Load and return the Hugging Face configuration of ``model_name``."""
    # Remote code is always trusted by this helper (trust_remote_code=True).
    return AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Build random mrope inputs for one benchmark configuration.

    Returns:
        ``(positions, query, key)`` where ``positions`` has shape
        ``(3, num_tokens)`` with values in ``[0, max_position_embeddings // 4)``,
        ``query`` has shape ``(num_tokens, num_q_heads * head_size)`` and
        ``key`` has shape ``(num_tokens, num_kv_heads * head_size)``.
    """
    # Sampling order (positions, query, key) is kept fixed so that a given
    # seed always yields the same tensors.
    position_limit = max_position_embeddings // 4
    positions = torch.randint(0, position_limit, (3, num_tokens), device=device)

    tensor_options = {"dtype": dtype, "device": device}
    query = torch.randn(num_tokens, num_q_heads * head_size, **tensor_options)
    key = torch.randn(num_tokens, num_kv_heads * head_size, **tensor_options)
    return positions, query, key
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_stats(times: list[float]) -> dict[str, float]:
    """Summarize a list of timings.

    Args:
        times: measured durations; must contain at least one sample.

    Returns:
        A dict with the keys ``mean``, ``median``, ``p99``, ``min`` and
        ``max``, each mapped to a built-in ``float``.

    Raises:
        ValueError: if ``times`` is empty.
    """
    if len(times) == 0:
        # Fail with a clear message instead of an opaque numpy reduction error.
        raise ValueError("calculate_stats requires at least one timing sample")

    samples = np.asarray(times, dtype=np.float64)
    # Convert numpy scalars to plain floats so the result matches the
    # annotated return type.
    return {
        "mean": float(np.mean(samples)),
        "median": float(np.median(samples)),
        "p99": float(np.percentile(samples, 99)),
        "min": float(np.min(samples)),
        "max": float(np.max(samples)),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _time_rope_calls(rope_fn, positions, query, key, num_iters):
    # Time num_iters calls of rope_fn on fresh clones of query / key and
    # return the per-call durations in seconds.
    durations = []
    for _ in range(num_iters):
        query_clone = query.clone()
        key_clone = key.clone()
        # Make sure the clones are finished before the clock starts.
        torch.cuda.synchronize()
        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()

        rope_fn(
            positions,
            query_clone,
            key_clone,
        )

        # Wait for the launched work so it is included in the measurement.
        torch.cuda.synchronize()
        durations.append(time.perf_counter() - start_time)
    return durations


def benchmark_mrope(
    model_name: str,
    num_tokens: int,
    head_dim: int,
    tp_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 8192,
    rope_theta: float = 10000,
    is_neox_style: bool = True,
    rope_scaling: dict[str, Any] = None,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 0,
    warmup_iter: int = 10,
    benchmark_iter: int = 100,
):
    """Benchmark ``forward_native`` against ``forward`` of the rope module.

    Builds the rope module with ``get_rope``, generates random inputs, warms
    both implementations up and then times each of them ``benchmark_iter``
    times, printing mean / median / p99 and the mean-based speedup.

    Args:
        model_name: model identifier, used for the printed report only.
        num_tokens: number of tokens in the generated inputs.
        head_dim: head size; also used as the rotary dimension.
        tp_size: tensor-parallel size, used for the printed report only.
        num_heads: number of query heads of the generated query.
        num_kv_heads: number of key/value heads of the generated key.
        max_position: maximum position passed to ``get_rope``.
        rope_theta: rotary base passed to ``get_rope``.
        is_neox_style: passed through to ``get_rope``.
        rope_scaling: rope scaling configuration passed to ``get_rope``.
        dtype: dtype of the generated tensors and of the rope module.
        seed: seed for ``torch.manual_seed``.
        warmup_iter: number of untimed warm-up iterations.
        benchmark_iter: number of timed iterations per implementation.

    Returns:
        ``(torch_stats, triton_stats)`` as produced by ``calculate_stats``.
    """
    torch.manual_seed(seed)
    torch.set_default_device(device)
    # the parameters to compute the q k v size based on tp_size
    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
    ).to(device=device)

    print(80 * "=")
    print(
        f"Evaluating model: {model_name} "
        f"with tp_size: {tp_size} "
        f"and num_tokens: {num_tokens}, "
        f"dtype: {dtype}"
    )

    # create q k v input tensors
    # create rotary pos emb input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Warm up
    for _ in range(warmup_iter):
        mrope_helper_class.forward_native(
            positions,
            query.clone(),
            key.clone(),
        )

        mrope_helper_class.forward(
            positions,
            query.clone(),
            key.clone(),
        )

    torch.cuda.synchronize()

    # Time reference implementation
    torch_times = _time_rope_calls(
        mrope_helper_class.forward_native, positions, query, key, benchmark_iter
    )

    # Time triton kernel implementation
    triton_times = _time_rope_calls(
        mrope_helper_class.forward, positions, query, key, benchmark_iter
    )

    # Calculate statistics
    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)
    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")

    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )

    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )

    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    return torch_stats, triton_stats
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: benchmark every (model, tp_size, num_tokens)
    # combination selected by the arguments.
    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    # NOTE(review): --trust-remote-code is parsed but not consulted anywhere
    # in this block.
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()
    print(args)

    # Without --model-name, sweep a default set of models and tp sizes.
    model_tp_dict = {}
    if args.model_name == "":
        model_tp_dict = {
            "Qwen/Qwen2-VL-2B-Instruct": [1],
            "Qwen/Qwen2-VL-7B-Instruct": [1],
            "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
            "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
            "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
            "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
        }
    else:
        model_tp_dict[args.model_name] = [args.tp_size]

    # Without --num-tokens, sweep powers of two from 1 to 2**17.
    if args.num_tokens is None:
        num_tokens_list = [2**i for i in range(0, 18)]
    else:
        num_tokens_list = args.num_tokens

    for model_name, tp_list in model_tp_dict.items():
        # Load the config once per model; it does not depend on tp_size.
        # The attention / rope settings are read from the text config.
        config = get_model_config(model_name).get_text_config()
        total_num_kv_heads = config.num_key_value_heads
        total_num_heads = config.num_attention_heads
        head_dim = (
            config.head_dim
            if hasattr(config, "head_dim")
            else config.hidden_size // total_num_heads
        )
        is_neox_style = True
        rope_theta = config.rope_theta
        max_position = config.max_position_embeddings

        for tp_size in tp_list:
            # Heads are split across tensor-parallel ranks; keep at least
            # one kv head per rank.
            num_heads = total_num_heads // tp_size
            num_kv_heads = max(1, total_num_kv_heads // tp_size)

            for num_tokens in num_tokens_list:
                benchmark_mrope(
                    model_name=model_name,
                    num_tokens=num_tokens,
                    head_dim=head_dim,
                    tp_size=tp_size,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    max_position=max_position,
                    rope_theta=rope_theta,
                    is_neox_style=is_neox_style,
                    rope_scaling=config.rope_scaling,
                    dtype=getattr(torch, args.dtype),
                    seed=args.seed,
                    warmup_iter=args.warmup_iter,
                    benchmark_iter=args.benchmark_iter,
                )
|
||||||
140
test/srt/rotary_embedding/test_mrope.py
Normal file
140
test/srt/rotary_embedding/test_mrope.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from packaging.version import Version
|
||||||
|
from transformers import AutoConfig
|
||||||
|
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||||
|
|
||||||
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
cpu_has_amx_support,
|
||||||
|
is_cpu,
|
||||||
|
is_cuda,
|
||||||
|
is_hip,
|
||||||
|
is_npu,
|
||||||
|
is_xpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Platform capability flags, evaluated once at import time.
# NOTE(review): only _is_cuda is referenced by the code visible in this file.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_is_xpu = is_xpu()

# Device for the test tensors and the rope module: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Build deterministic random mrope inputs for one test configuration.

    The global torch seed is reset to 42 on every call, so equal arguments
    always produce equal tensors.

    Returns:
        ``(positions, query, key)`` where ``positions`` has shape
        ``(3, num_tokens)`` with values in ``[0, max_position_embeddings // 4)``,
        ``query`` has shape ``(num_tokens, num_q_heads * head_size)`` and
        ``key`` has shape ``(num_tokens, num_kv_heads * head_size)``.
    """
    torch.manual_seed(42)

    # Sampling order (positions, query, key) must stay fixed for reproducibility.
    position_limit = max_position_embeddings // 4
    positions = torch.randint(0, position_limit, (3, num_tokens), device=device)

    tensor_options = {"dtype": dtype, "device": device}
    query = torch.randn(num_tokens, num_q_heads * head_size, **tensor_options)
    key = torch.randn(num_tokens, num_kv_heads * head_size, **tensor_options)
    return positions, query, key
|
||||||
|
|
||||||
|
|
||||||
|
class MRoPETestInfo(NamedTuple):
    """One model under test together with its comparison tolerances."""

    # Hugging Face model identifier whose config is loaded by the test.
    model_name: str
    # Absolute / relative tolerances used when comparing the two outputs.
    atol: float = 1e-2
    rtol: float = 1.6e-2
    # pytest marks attached to the generated parameter set. The default is an
    # immutable tuple: a NamedTuple default is a single object shared by all
    # instances, so a mutable list could leak changes between test cases.
    marks: tuple[pytest.MarkDecorator, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
|
# Base version string of the installed transformers package.
# NOTE(review): not referenced by the code visible in this file.
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version

# Models whose configurations parametrize test_mrope (default tolerances).
MODELS_TO_TEST = [
    MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
    MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
    MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
]

# Token counts exercised by test_mrope.
num_tokens_list = [11, 8192]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _is_cuda, reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
    "model_info, model_name",
    [
        pytest.param(test_config, test_config.model_name, marks=test_config.marks)
        for test_config in MODELS_TO_TEST
    ],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(
    model_name: str,
    model_info: MRoPETestInfo,
    tp_size: int,
    dtype: torch.dtype,
    num_tokens: int,
):
    """Check that ``forward`` matches ``forward_native`` of the rope module.

    The rope module is built from the text config of ``model_name`` for the
    given tensor-parallel size, both implementations are run on clones of
    the same random inputs, and the outputs are compared with the tolerances
    stored in ``model_info``.
    """
    text_config = AutoConfig.from_pretrained(model_name).get_text_config()

    # Per-rank head counts; at least one kv head is kept per rank.
    total_num_kv_heads = text_config.num_key_value_heads
    total_num_heads = text_config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)

    if hasattr(text_config, "head_dim"):
        head_dim = text_config.head_dim
    else:
        head_dim = text_config.hidden_size // total_num_heads

    max_position = text_config.max_position_embeddings
    partial_rotary_factor = getattr(text_config, "partial_rotary_factor", 1.0)

    rope_kwargs = dict(
        head_size=head_dim,
        rotary_dim=int(head_dim * partial_rotary_factor),
        max_position=max_position,
        base=text_config.rope_theta,
        is_neox_style=True,
        rope_scaling=text_config.rope_scaling,
        dtype=dtype,
    )
    mrope_helper_class = get_rope(**rope_kwargs).to(device=device)

    # Random (3, num_tokens) positions plus flattened query / key tensors.
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Each implementation gets its own clones of the inputs.
    query_native, key_native = mrope_helper_class.forward_native(
        positions, query.clone(), key.clone()
    )
    query_cuda, key_cuda = mrope_helper_class.forward(
        positions, query.clone(), key.clone()
    )

    tolerances = {"atol": model_info.atol, "rtol": model_info.rtol}
    torch.testing.assert_close(query_native, query_cuda, **tolerances)
    torch.testing.assert_close(key_native, key_cuda, **tolerances)
|
||||||
@@ -77,6 +77,7 @@ suites = {
|
|||||||
TestFile("test_eval_fp8_accuracy.py", 303),
|
TestFile("test_eval_fp8_accuracy.py", 303),
|
||||||
TestFile("test_fa3.py", 376),
|
TestFile("test_fa3.py", 376),
|
||||||
# TestFile("test_flashmla.py", 352),
|
# TestFile("test_flashmla.py", 352),
|
||||||
|
TestFile("rotary_embedding/test_mrope.py", 300),
|
||||||
TestFile("test_function_call_parser.py", 10),
|
TestFile("test_function_call_parser.py", 10),
|
||||||
TestFile("test_fused_moe.py", 30),
|
TestFile("test_fused_moe.py", 30),
|
||||||
TestFile("test_gpt_oss_1gpu.py", 600),
|
TestFile("test_gpt_oss_1gpu.py", 600),
|
||||||
|
|||||||
Reference in New Issue
Block a user