Support mrope triton kernel and add unit test (#11722)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -7,6 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import (
|
||||
@@ -1033,6 +1035,188 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.T
|
||||
return x_t
|
||||
|
||||
|
||||
@triton.jit
def _triton_mrope_forward(
    q_ptr,  # [num_tokens, n_qh * hd] flattened; rotated IN PLACE
    k_ptr,  # [num_tokens, n_kh * hd] flattened; rotated IN PLACE
    cos,  # cos cache; stride math below assumes (3, num_tokens, rd // 2), contiguous
    sin,  # sin cache; same layout as `cos`
    num_tokens,  # number of tokens (= stride between the T/H/W planes, in rows)
    n_qh: tl.constexpr,  # number of query heads
    n_kh: tl.constexpr,  # number of key (kv) heads
    hd: tl.constexpr,  # head dim
    rd: tl.constexpr,  # rotary dim (first rd channels of each head are rotated)
    pad_n_qh: tl.constexpr,  # n_qh padded to next power of 2 (for tl.arange)
    pad_n_kh: tl.constexpr,  # n_kh padded to next power of 2
    pad_hd: tl.constexpr,  # hd padded to next power of 2
    mrope_section_t: tl.constexpr,  # temporal section size (in half-dim units)
    mrope_section_h: tl.constexpr,  # height section size
    mrope_section_w: tl.constexpr,  # width section size
    is_interleaved: tl.constexpr,  # True: T/H/W channels interleaved; False: contiguous sections
):
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # This version supports flatten input tensors from vllm
    # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
    # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
    #
    # One program instance handles ONE token: all of its query heads and all
    # of its key heads.
    pid = tl.program_id(0)
    # locate start address of this token's heads in the flattened q/k buffers
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)

    # Updated stride calculation for half head_dim.
    # The three planes (temporal, height, width) are stacked along dim 0,
    # so consecutive planes are `num_tokens * half_rd` elements apart.
    half_rd = rd // 2
    t_cos = cos + pid * half_rd
    h_cos = t_cos + num_tokens * half_rd
    w_cos = h_cos + num_tokens * half_rd
    t_sin = sin + pid * half_rd
    h_sin = t_sin + num_tokens * half_rd
    w_sin = h_sin + num_tokens * half_rd

    # Updated offsets for half head_dim.
    # Build three disjoint masks selecting which of the T/H/W planes each
    # rotary channel reads from; the masked loads below use other=0, so
    # summing the three rows merges them into a single cos/sin row.
    cos_offsets = tl.arange(0, pad_hd // 2)
    if is_interleaved:
        # Interleaved layout: channel index mod 3 selects the plane
        # (0 -> T, 1 -> H, 2 -> W), limited to each section's extent;
        # everything not claimed by H/W falls back to T.
        h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
        w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
        t_mask = ~(h_mask | w_mask)
    else:
        # Contiguous layout: [0, t) -> T, [t, t+h) -> H, [t+h, half_rd) -> W.
        t_end = mrope_section_t
        h_end = t_end + mrope_section_h
        t_mask = cos_offsets < mrope_section_t
        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)

    # The masks partition the channel range, so each position gets exactly
    # one plane's value; the sums select, they do not accumulate.
    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head: offsets into [head, rotary-channel] tiles
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    # Masks clip the power-of-two padding back to the real head count and
    # the real rotary half-width.
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )

    # Compute in the cache's dtype to match the precision of cos/sin.
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head: same tile shifted by rd // 2 channels
    second_half_q_offsets = first_half_q_offsets + (rd // 2)
    second_half_k_offsets = first_half_k_offsets + (rd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask

    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are now half-size,
    # we use the same cos_row and sin_row for both halves
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
||||
|
||||
|
||||
def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the mrope triton kernel, rotating ``q`` and ``k`` in place.

    Args:
        q: [num_tokens, num_heads * head_size]
        k: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, head_size //2 ]
            (T/H/W positions with multimodal inputs)
        sin: [3, num_tokens, head_size //2 ]
            (T/H/W positions with multimodal inputs)
        mrope_section: [t, h, w]
        head_size: int
        rotary_dim: number of channels per head that receive rotation
        mrope_interleaved: whether T/H/W rotary channels are interleaved

    Returns:
        The (q, k) pair after rotation (same storage as the inputs when
        they were already contiguous).
    """
    num_tokens, q_width = q.shape
    assert (
        q_width % head_size == 0
    ), f"q shape {q_width} must be divisible by head_size {head_size}"
    num_q_heads = q_width // head_size

    k_width = k.shape[1]
    assert (
        k_width % head_size == 0
    ), f"k shape {k_width} must be divisible by head_size {head_size}"
    num_kv_heads = k_width // head_size

    # tl.arange requires power-of-two extents; the kernel masks the padding
    # back out, so only the launch-time constants change here.
    padded_hd = triton.next_power_of_2(head_size)
    padded_q_heads = triton.next_power_of_2(num_q_heads)
    padded_kv_heads = triton.next_power_of_2(num_kv_heads)

    # The kernel does raw pointer arithmetic, so every tensor must be
    # contiguous. .contiguous() is a no-op when it already is.
    q, k = q.contiguous(), k.contiguous()
    cos, sin = cos.contiguous(), sin.contiguous()

    # One program instance per token.
    _triton_mrope_forward[(num_tokens,)](
        q,
        k,
        cos,
        sin,
        num_tokens,
        num_q_heads,
        num_kv_heads,
        head_size,
        rotary_dim,
        padded_q_heads,
        padded_kv_heads,
        padded_hd,
        mrope_section[0],
        mrope_section[1],
        mrope_section[2],
        mrope_interleaved,
    )
    return q, k
||||
|
||||
|
||||
class MRotaryEmbedding(RotaryEmbedding):
|
||||
"""Rotary Embedding with Multimodal Sections."""
|
||||
|
||||
@@ -1086,8 +1270,17 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
||||
)
|
||||
|
||||
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
|
||||
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||
# is expensive, so avoid calling it if possible
|
||||
if (
|
||||
self.cos_sin_cache.device != query.device
|
||||
or self.cos_sin_cache.dtype != query.dtype
|
||||
):
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def forward(
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
@@ -1141,6 +1334,51 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply (multimodal) rotary embedding to query and key.

        Args:
            positions: 1-D [num_tokens] for plain text, or 2-D for the
                multimodal (mrope) path — presumably (3, num_tokens) with
                T/H/W positions per token; verify against callers.
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
            fused_set_kv_buffer_arg: NOTE(review): accepted but never read
                in this method — confirm callers do not rely on it taking
                effect on this path.

        Returns:
            The rotated (query, key), reshaped back to their input shapes.
        """
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        self._match_cos_sin_cache_dtype(query)
        num_tokens = positions.shape[-1]
        # cos_sin rows are [cos | sin] concatenated along the last dim.
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        key_shape = key.shape
        if positions.ndim == 2:
            # Multimodal path: per-axis (T/H/W) positions — use the fused
            # triton kernel, which rotates q/k in place.
            assert self.mrope_section

            q, k = triton_mrope(
                query,
                key,
                cos,
                sin,
                self.mrope_section,
                self.head_size,
                self.rotary_dim,
                self.mrope_interleaved,
            )

            return q.reshape(query_shape), k.reshape(key_shape)

        # Plain-text path: rotate only the first rotary_dim channels of each
        # head and pass the remainder through unchanged.
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
|
||||
@staticmethod
|
||||
def get_rope_index(
|
||||
|
||||
Reference in New Issue
Block a user