113 lines
3.8 KiB
Python
113 lines
3.8 KiB
Python
"""
|
|
Copyright 2023-2024 SGLang Team
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
"""MRotaryEmbedding"""
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
|
|
class MRotaryEmbedding:
    """Rotary Embedding with Multimodal Sections.

    This class only does position-id bookkeeping. Every token gets three
    position ids, one per row:
    - Row 0 comes from the temporal grid index of an image token.
    - Row 1 comes from the height index.
    - Row 2 comes from the width index.
    Text tokens get the same id in all three rows.
    """

    @staticmethod
    def _grid_positions(grid_t: int, grid_h: int, grid_w: int) -> torch.Tensor:
        """Return a ``(3, grid_t * grid_h * grid_w)`` tensor of grid indices.

        Row 0 holds the temporal index, row 1 the height index and row 2 the
        width index of each cell. Cells are enumerated with ``t`` varying
        slowest and ``w`` fastest.
        """
        t_index = (
            torch.arange(grid_t)
            .view(-1, 1)
            .expand(-1, grid_h * grid_w)
            .flatten()
        )
        h_index = (
            torch.arange(grid_h)
            .view(1, -1, 1)
            .expand(grid_t, -1, grid_w)
            .flatten()
        )
        w_index = (
            torch.arange(grid_w)
            .view(1, 1, -1)
            .expand(grid_t, grid_h, -1)
            .flatten()
        )
        return torch.stack([t_index, h_index, w_index])

    @staticmethod
    def get_input_positions(
        input_tokens: torch.Tensor,
        image_grid_thw: Union[List[List[int]], torch.Tensor],
        vision_start_token_id: int,
        spatial_merge_size: int,
        context_len: int = 0,
    ) -> Tuple[List[List[int]], int]:
        """Get mrope input positions and delta value.

        Args:
            input_tokens: Token ids of the prompt. Assumed to be a 1-D tensor;
                the ``squeeze(1)`` after ``argwhere`` relies on that.
            image_grid_thw: One ``(t, h, w)`` entry per image. Entries are
                consumed in the order the vision-start tokens appear in
                ``input_tokens``. A tensor is converted with ``tolist()``.
            vision_start_token_id: Token id that precedes each image. The
                image's tokens are expected to begin right after it.
            spatial_merge_size: Divisor applied to ``h`` and ``w`` to get the
                image's token-grid height and width.
            context_len: Number of leading positions dropped from the
                returned positions. Presumably this is an already-processed
                prefix; confirm against callers.

        Returns:
            A ``(positions, delta)`` tuple:
            - ``positions`` is three lists (temporal / height / width rows)
              for tokens ``context_len:``.
            - ``delta`` is ``max position + 1 - len(input_tokens)``, computed
              over the whole sequence. It is 0 for an empty input.
        """
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()

        input_tokens_len = input_tokens.shape[0]
        # First token of each image sits right after its vision-start token.
        image_indices = (
            torch.argwhere(input_tokens == vision_start_token_id).squeeze(1) + 1
        )

        llm_pos_ids_list: List[torch.Tensor] = []
        st = 0  # index in input_tokens where the next text segment starts

        for image_index in range(image_indices.shape[0]):
            ed = image_indices[image_index].item()
            t = image_grid_thw[image_index][0]
            h = image_grid_thw[image_index][1]
            w = image_grid_thw[image_index][2]
            llm_grid_t = t
            llm_grid_h = h // spatial_merge_size
            llm_grid_w = w // spatial_merge_size

            # Text between the previous image (or the start) and this image.
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            # Image tokens: grid coordinates, offset past the preceding text.
            llm_pos_ids_list.append(
                MRotaryEmbedding._grid_positions(llm_grid_t, llm_grid_h, llm_grid_w)
                + text_len
                + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last image.
        if st < input_tokens_len:
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = input_tokens_len - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        if not llm_pos_ids_list:
            # Empty input: torch.cat rejects an empty list, and there is no
            # maximum position to derive a delta from.
            return [[], [], []], 0

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        # Compute the delta over the full sequence *before* dropping the
        # prefix. This avoids calling max() on an empty tensor when
        # context_len >= sequence length. Otherwise the result is unchanged,
        # because the last column always holds the largest position.
        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
        llm_positions = llm_positions[:, context_len:]
        return llm_positions.tolist(), mrope_position_delta

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> List[List[int]]:
        """Return positions for tokens ``context_len .. seq_len - 1``.

        Each token index is shifted by ``mrope_position_delta``. The same
        shifted range is repeated for all three rows.
        """
        return [
            list(
                range(
                    context_len + mrope_position_delta, seq_len + mrope_position_delta
                )
            )
            for _ in range(3)
        ]
|