413 lines
14 KiB
Python
413 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from .base import RotaryEmbeddingBase
|
|
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
|
|
|
|
|
@triton.jit
def _triton_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    rd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    mrope_section_w: tl.constexpr,
    is_interleaved: tl.constexpr,
):
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # This version supports flattened input tensors from vllm
    # and supports cos and sin cache with shape (3, num_tokens, rd // 2)
    # instead of (3, bsz, seq_len, head_dim); it also supports interleaved
    # rotary.
    #
    # Each program instance handles one token: it rotates, in place, the
    # first `rd` channels of every q head and every k head of that token.
    # Arguments (as used below):
    #   q_ptr / k_ptr: flattened rows of n_qh * hd and n_kh * hd elements
    #   cos / sin:     three consecutive planes (T, H, W), each made of
    #                  num_tokens rows of rd // 2 elements
    #   pad_*:         extents used for tl.arange; lanes beyond the real
    #                  sizes are masked off
    #   mrope_section_{t,h,w}: number of frequency lanes taken from each plane
    #   is_interleaved: pick lanes in a T,H,W,T,H,W,... pattern instead of
    #                  three contiguous chunks
    pid = tl.program_id(0)
    # locate start address
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    # Note: cos and sin have shape (3, num_tokens, rd // 2), so one row is
    # half_rd elements and one T/H/W plane is num_tokens * half_rd elements.
    half_rd = rd // 2
    t_cos = cos + pid * half_rd
    h_cos = t_cos + num_tokens * half_rd
    w_cos = h_cos + num_tokens * half_rd
    t_sin = sin + pid * half_rd
    h_sin = t_sin + num_tokens * half_rd
    w_sin = h_sin + num_tokens * half_rd

    # Lane indices over the (padded) half head; only lanes < half_rd exist
    # in a cos/sin row.
    cos_offsets = tl.arange(0, pad_hd // 2)
    if is_interleaved:
        # pad_hd // 2 may exceed half_rd (rotary_dim < head_size, or a
        # head_size that is not a power of two). Bound every mask by the real
        # row length so that no lane reads past this token's cos/sin row.
        in_row = cos_offsets < half_rd
        # `% 3 == 1` / `% 3 == 2` can never hold at exactly 3 * section, so
        # the strict bound selects the same lanes as an inclusive one.
        h_mask = in_row & ((cos_offsets % 3) == 1) & (cos_offsets < 3 * mrope_section_h)
        w_mask = in_row & ((cos_offsets % 3) == 2) & (cos_offsets < 3 * mrope_section_w)
        t_mask = in_row & ~(h_mask | w_mask)
    else:
        t_end = mrope_section_t
        h_end = t_end + mrope_section_h
        t_mask = cos_offsets < mrope_section_t
        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

    # The three masks are disjoint, so summing the masked loads (0 elsewhere)
    # assembles a single row mixing T/H/W values.
    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)

    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    # Mask off padded heads and lanes outside the rotated half.
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )

    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (rd // 2)
    second_half_k_offsets = first_half_k_offsets + (rd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask

    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are half-size,
    # we use the same cos_row and sin_row for both halves
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
|
|
|
|
|
def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the Qwen2VL mrope Triton kernel, one program per token.

    Args:
        q: [num_tokens, num_heads * head_size]
        k: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, rotary_dim // 2]
            (T/H/W positions with multimodal inputs; the kernel strides
            rows by rotary_dim // 2)
        sin: [3, num_tokens, rotary_dim // 2]
            (T/H/W positions with multimodal inputs)
        mrope_section: [t, h, w]
        head_size: int
        rotary_dim: number of leading channels of each head that are rotated
        mrope_interleaved: forwarded to the kernel as ``is_interleaved``

    Returns:
        The contiguous ``q`` and ``k`` tensors that were handed to the
        kernel, which rotates them in place. They are the input objects
        themselves when the inputs were already contiguous.
    """
    token_count, q_width = q.shape
    q_heads = q_width // head_size
    kv_heads = k.shape[1] // head_size

    # Round sizes up to powers of two; the kernel uses them as tl.arange
    # extents and masks off the padding.
    padded_head_dim = triton.next_power_of_2(head_size)
    padded_q_heads = triton.next_power_of_2(q_heads)
    padded_kv_heads = triton.next_power_of_2(kv_heads)

    # The kernel addresses rows with fixed strides, so every tensor must be
    # contiguous. `.contiguous()` is a no-op for already-contiguous tensors.
    q, k = q.contiguous(), k.contiguous()
    cos, sin = cos.contiguous(), sin.contiguous()

    launch_grid = (token_count,)
    _triton_mrope_forward[launch_grid](
        q,
        k,
        cos,
        sin,
        token_count,
        q_heads,
        kv_heads,
        head_size,
        rotary_dim,
        padded_q_heads,
        padded_kv_heads,
        padded_head_dim,
        mrope_section[0],
        mrope_section[1],
        mrope_section[2],
        mrope_interleaved,
    )
    return q, k
|
|
|
|
|
|
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    """Build the interleaved MRoPE table from stacked T/H/W embeddings.

    Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THWTHWTHW...TT]: starting from a copy of the T plane
    ``x[0]``, every third lane from index 1 (below ``3 * mrope_section[1]``)
    is taken from the H plane ``x[1]`` and every third lane from index 2
    (below ``3 * mrope_section[2]``) from the W plane ``x[2]``.

    Args:
        x: tensor whose leading dimension indexes the T/H/W planes and whose
            last dimension indexes frequency lanes.
        mrope_section: [t, h, w] lane counts; only h and w are read here.

    Returns:
        A new tensor shaped like ``x[0]``; ``x`` itself is not modified.
    """
    mixed = x[0].clone()
    # Plane 1 (H) fills lanes 1, 4, 7, ...; plane 2 (W) fills lanes 2, 5, 8, ...
    for plane in (1, 2):
        lanes = slice(plane, mrope_section[plane] * 3, 3)
        mixed[..., lanes] = x[plane, ..., lanes]
    return mixed
|
|
|
|
|
|
class MRotaryEmbedding(RotaryEmbeddingBase):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: list[int] | None = None,
        mrope_interleaved: bool = False,
        # YaRN parameters.
        *,
        scaling_factor: float | None = None,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # The YaRN settings are stored before calling the base initializer.
        # NOTE(review): presumably the base initializer invokes
        # _compute_inv_freq / _compute_cos_sin_cache, which read them -
        # confirm against RotaryEmbeddingBase before reordering.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        if self.scaling_factor is None:
            self.mscale = 1.0
        else:
            # Get n-d magnitude scaling corrected for interpolation
            self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)

        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings to 4 times to get
        # a larger cos and sin cache.
        self.cache_max_position_num = max_position_embeddings * 4
        super().__init__(
            head_size,
            rotary_dim,
            self.cache_max_position_num,
            base,
            is_neox_style,
            dtype,
        )

        self.mrope_section = mrope_section
        self.mrope_interleaved = mrope_interleaved
        if self.mrope_section:
            # The T/H/W sections must exactly tile the rotated half head.
            assert sum(self.mrope_section) == rotary_dim // 2

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return inverse frequencies, using the YaRN variant when scaling."""
        if self.scaling_factor is not None:
            return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
        return super()._compute_inv_freq(base)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Return the cos/sin cache, using the YaRN variant when scaling."""
        if self.scaling_factor is not None:
            return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
        return super()._compute_cos_sin_cache()

    def _mrope_rotate_partial(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        num_tokens: int,
        rotary_fn,
    ) -> torch.Tensor:
        """Rotate the first ``rotary_dim`` channels of every head of ``x``.

        ``x`` is viewed as [num_tokens, -1, head_size]; the leading
        ``rotary_dim`` channels go through ``rotary_fn(x_rot, cos, sin)``,
        the remaining channels pass through untouched, and the result is
        reshaped back to the original shape of ``x``.
        """
        flat_shape = x.shape
        per_head = x.view(num_tokens, -1, self.head_size)
        rotated_part = per_head[..., : self.rotary_dim]
        untouched_part = per_head[..., self.rotary_dim :]
        rotated_part = rotary_fn(rotated_part, cos, sin)
        return torch.cat((rotated_part, untouched_part), dim=-1).reshape(flat_shape)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
            offsets: accepted for interface compatibility; not used here.

        Returns:
            The rotated (query, key), each reshaped to its input shape.
        """
        assert positions.ndim in (1, 2)
        assert key is not None

        self._match_cos_sin_cache_dtype(query)
        num_tokens = positions.shape[-1]
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

        if positions.ndim == 2:
            assert self.mrope_section
            sections = self.mrope_section
            if self.mrope_interleaved:
                cos = apply_interleaved_rope(cos, sections)
                sin = apply_interleaved_rope(sin, sections)
            else:
                # Chunk i of the last dimension is taken from plane i
                # (0 = T, 1 = H, 2 = W), then the chunks are re-joined.
                cos = torch.cat(
                    [
                        chunk[plane]
                        for plane, chunk in enumerate(cos.split(sections, dim=-1))
                    ],
                    dim=-1,
                )
                sin = torch.cat(
                    [
                        chunk[plane]
                        for plane, chunk in enumerate(sin.split(sections, dim=-1))
                    ],
                    dim=-1,
                )

        query = self._mrope_rotate_partial(
            query, cos, sin, num_tokens, self.apply_rotary_emb.forward_native
        )
        key = self._mrope_rotate_partial(
            key, cos, sin, num_tokens, self.apply_rotary_emb.forward_native
        )
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """CUDA path: Triton kernel for 2-D positions, rotary op otherwise.

        Takes the same arguments as ``forward_native``; ``offsets`` is not
        used here.
        """
        assert positions.ndim in (1, 2)
        assert key is not None

        self._match_cos_sin_cache_dtype(query)
        num_tokens = positions.shape[-1]
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

        if positions.ndim == 2:
            assert self.mrope_section
            original_q_shape = query.shape
            original_k_shape = key.shape

            # The kernel picks the T/H/W lanes itself, so the stacked
            # [3, num_tokens, ...] cos/sin are passed through unmodified.
            q, k = triton_mrope(
                query,
                key,
                cos,
                sin,
                self.mrope_section,
                self.head_size,
                self.rotary_dim,
                self.mrope_interleaved,
            )
            return q.reshape(original_q_shape), k.reshape(original_k_shape)

        # Text-only (1-D positions): regular rotary embedding.
        query = self._mrope_rotate_partial(
            query, cos, sin, num_tokens, self.apply_rotary_emb
        )
        key = self._mrope_rotate_partial(
            key, cos, sin, num_tokens, self.apply_rotary_emb
        )
        return query, key

    def forward_cpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """CPU path; delegates to ``forward_native``."""
        return self.forward_native(positions, query, key, offsets)

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> list[list[int]]:
        """Return three independent position lists (one per T/H/W axis).

        Each list is ``range(context_len + delta, seq_len + delta)``.
        """
        first = context_len + mrope_position_delta
        stop = seq_len + mrope_position_delta
        return [list(range(first, stop)) for _ in range(3)]

    @staticmethod
    def get_next_input_positions_tensor(
        out: np.ndarray,
        out_offset: int,
        mrope_position_delta: int,
        context_len: int,
        num_new_tokens: int,
    ):
        """Write the next positions into ``out`` in place.

        Columns ``out_offset : out_offset + num_new_tokens`` of every row of
        ``out`` receive ``num_new_tokens`` consecutive values starting at
        ``mrope_position_delta + context_len``.
        """
        first = mrope_position_delta + context_len
        values = np.arange(first, first + num_new_tokens, dtype=out.dtype)
        out[:, out_offset : out_offset + num_new_tokens] = values
|