init
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
from .rotary_embedding import (
|
||||
RotaryEmbedding_init_vacc,
|
||||
RotaryEmbedding_forward_vacc,
|
||||
ScalingRotaryEmbedding_forward_vacc,
|
||||
_compute_inv_freq_vacc,
|
||||
_deepseek_compute_cos_sin_cache_vacc,
|
||||
_yarn_compute_cos_sin_cache_vacc,
|
||||
_compute_cos_sin_cache_vacc
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
101
vllm_vacc/vllm/model_executor/layers/rotary_embedding/mrope.py
Normal file
101
vllm_vacc/vllm/model_executor/layers/rotary_embedding/mrope.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import itertools
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MRotaryEmbedding:
|
||||
@classmethod
|
||||
def _qwen3vl_get_input_positions_tensor(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
|
||||
"""Get mrope input positions and delta value."""
|
||||
|
||||
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
|
||||
for _ in range(t)]
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
vision_start_token_id = hf_config.vision_start_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + video_nums):
|
||||
if image_token_id in input_tokens and remain_images > 0:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if video_token_id in input_tokens and remain_videos > 0:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = \
|
||||
t, h // spatial_merge_size, w // spatial_merge_size
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||
llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
|
||||
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
|
||||
-1, llm_grid_h * llm_grid_w).flatten()
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
||||
llm_grid_t, -1, llm_grid_w).flatten()
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
||||
llm_grid_t, llm_grid_h, -1).flatten()
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||
llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 -
|
||||
len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
return llm_positions, mrope_position_delta
|
||||
@@ -0,0 +1,203 @@
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb
|
||||
# from vllm.model_executor.layers.rotary_embedding import _yarn_find_correction_range, _yarn_linear_ramp_mask
|
||||
from vllm.model_executor.layers.rotary_embedding.common import yarn_find_correction_range as _yarn_find_correction_range
|
||||
from vllm.model_executor.layers.rotary_embedding.common import yarn_linear_ramp_mask as _yarn_linear_ramp_mask
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from ...ops.mrope_op import get_sin_cos_mrope
|
||||
|
||||
def RotaryEmbedding_init_vacc(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super(CustomOp, self).__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
# cache = self._compute_cos_sin_cache()
|
||||
cos, sin = self._compute_cos_sin_cache()
|
||||
cos = cos.to(dtype)
|
||||
sin = sin.to(dtype)
|
||||
|
||||
self.register_buffer("cos_cache", cos, persistent=False)
|
||||
self.register_buffer("sin_cache", sin, persistent=False)
|
||||
|
||||
# cache = cache.to(dtype)
|
||||
# self.cos_sin_cache: torch.Tensor
|
||||
# self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def RotaryEmbedding_forward_vacc(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-vacc implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
num_tokens = positions.numel()
|
||||
# positions = positions.flatten()
|
||||
# num_tokens = positions.shape[0]
|
||||
# cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
# cos_sin = self.cos_sin_cache[positions]
|
||||
# cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
|
||||
if isinstance(self, MRotaryEmbedding):
|
||||
# get mrope sin/cos
|
||||
cos, sin = get_sin_cos_mrope(self, positions)
|
||||
num_tokens = num_tokens//3
|
||||
else:
|
||||
positions = positions.flatten()
|
||||
cos = self.cos_cache[positions]
|
||||
sin = self.sin_cache[positions]
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
mode = "neox"
|
||||
if not self.is_neox_style:
|
||||
mode = "gptj"
|
||||
query_rot, key_rot=torch.vacc.RotaryPosEmbedding(query_rot, key_rot, cos, sin, 0, mode)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
return query, key
|
||||
|
||||
def ScalingRotaryEmbedding_forward_vacc(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
|
||||
# self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
|
||||
# positions.device)
|
||||
# cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
||||
# if offsets is not None else positions]
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
positions = positions.flatten()
|
||||
|
||||
# cos_sin = self.cos_sin_cache[positions]
|
||||
# cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
cos = self.cos_cache[positions]
|
||||
sin = self.sin_cache[positions]
|
||||
|
||||
# TODO: to be removed (require odsp support)
|
||||
# if self.is_neox_style:
|
||||
# # NOTE(woosuk): Here we assume that the positions tensor has the
|
||||
# # shape [batch_size, seq_len].
|
||||
# cos = cos.repeat(1, 1, 2).unsqueeze(-2)
|
||||
# sin = sin.repeat(1, 1, 2).unsqueeze(-2)
|
||||
# else:
|
||||
# cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
||||
# sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
|
||||
# rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
||||
mode = "neox" if self.is_neox_style else "gptj"
|
||||
# query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
||||
# key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
||||
query_rot, key_rot=torch.vacc.RotaryPosEmbedding(query_rot, key_rot, cos, sin, 0, mode)
|
||||
|
||||
if self.rotary_dim < self.head_size:
|
||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||
else:
|
||||
query = query_rot
|
||||
key = key_rot
|
||||
return query, key
|
||||
|
||||
def _compute_inv_freq_vacc(self, scaling_factor: float) -> torch.Tensor:
|
||||
pos_freqs = self.base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float, device=current_platform.device_type) /
|
||||
self.rotary_dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
||||
|
||||
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
|
||||
self.rotary_dim, self.base,
|
||||
self.max_position_embeddings)
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
|
||||
low, high, self.rotary_dim // 2,
|
||||
dtype=torch.float)) * self.extrapolation_factor
|
||||
inv_freq = inv_freq_interpolation * (
|
||||
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
return inv_freq
|
||||
|
||||
def _deepseek_compute_cos_sin_cache_vacc(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
|
||||
device=current_platform.device_type,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
# NOTE: for odsp friendly
|
||||
# seperate cos/sin cache can gurantee cos/sin
|
||||
# always has contigous layout for dim[-1]
|
||||
cos = (freqs.cos() * self.mscale)
|
||||
sin = (freqs.sin() * self.mscale)
|
||||
return cos, sin
|
||||
# cache = torch.cat((cos, sin), dim=-1)
|
||||
# return cache
|
||||
|
||||
def _yarn_compute_cos_sin_cache_vacc(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
|
||||
device=current_platform.device_type,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = (freqs.cos() * self.mscale)
|
||||
sin = (freqs.sin() * self.mscale)
|
||||
return cos, sin
|
||||
# cache = torch.cat((cos, sin), dim=-1)
|
||||
# return cache
|
||||
|
||||
def _compute_cos_sin_cache_vacc(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
device=current_platform.device_type,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
# NOTE: for odsp friendly
|
||||
# seperate cos/sin cache can gurantee cos/sin
|
||||
# always has contigous layout for dim[-1]
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
return cos, sin
|
||||
# cache = torch.cat((cos, sin), dim=-1)
|
||||
# return cache
|
||||
|
||||
|
||||
# import vllm.model_executor.layers.rotary_embedding as rotary_embedding
|
||||
# rotary_embedding.RotaryEmbedding.forward_vacc=RotaryEmbedding_forward_vacc
|
||||
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_inv_freq=_compute_inv_freq_vacc
|
||||
# rotary_embedding.DeepseekScalingRotaryEmbedding._compute_cos_sin_cache=_compute_cos_sin_cache_vacc
|
||||
|
||||
Reference in New Issue
Block a user