Support qwen2 vl model (#1721)
Co-authored-by: yizhang2077 <1109276519@qq.com> Co-authored-by: ispobock <ISPObaoke@163.com>
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
|||||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 20
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
|
python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
|
||||||
@@ -93,7 +93,7 @@ jobs:
|
|||||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 20
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 17
|
python3 run_suite.py --suite minimal --range-begin 17
|
||||||
|
|||||||
@@ -133,6 +133,22 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="qwen2-vl",
|
||||||
|
default_system_prompt="You are a helpful assistant.",
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
||||||
|
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
||||||
|
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
||||||
|
},
|
||||||
|
style=ChatTemplateStyle.PLAIN,
|
||||||
|
stop_str=("<|im_end|>"),
|
||||||
|
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_chat_template(
|
register_chat_template(
|
||||||
ChatTemplate(
|
ChatTemplate(
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
from sglang.srt.configs.exaone import ExaoneConfig
|
from sglang.srt.configs.exaone import ExaoneConfig
|
||||||
|
from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ExaoneConfig",
|
"ExaoneConfig",
|
||||||
|
"Qwen2VLConfig",
|
||||||
|
"Qwen2VLVisionConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
133
python/sglang/srt/configs/qwen2vl.py
Normal file
133
python/sglang/srt/configs/qwen2vl.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Qwen2VL model configuration"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLVisionConfig(PretrainedConfig):
|
||||||
|
model_type = "qwen2_vl"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
depth=32,
|
||||||
|
embed_dim=1280,
|
||||||
|
hidden_size=3584,
|
||||||
|
hidden_act="quick_gelu",
|
||||||
|
mlp_ratio=4,
|
||||||
|
num_heads=16,
|
||||||
|
in_channels=3,
|
||||||
|
patch_size=14,
|
||||||
|
spatial_merge_size=2,
|
||||||
|
temporal_patch_size=2,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.depth = depth
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.spatial_merge_size = spatial_merge_size
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||||
|
) -> "PretrainedConfig":
|
||||||
|
cls._set_token_in_kwargs(kwargs)
|
||||||
|
|
||||||
|
config_dict, kwargs = cls.get_config_dict(
|
||||||
|
pretrained_model_name_or_path, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if config_dict.get("model_type") == "qwen2_vl":
|
||||||
|
config_dict = config_dict["vision_config"]
|
||||||
|
|
||||||
|
return cls.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLConfig(PretrainedConfig):
|
||||||
|
model_type = "qwen2_vl"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=152064,
|
||||||
|
hidden_size=8192,
|
||||||
|
intermediate_size=29568,
|
||||||
|
num_hidden_layers=80,
|
||||||
|
num_attention_heads=64,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=32768,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-05,
|
||||||
|
use_cache=True,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=1000000.0,
|
||||||
|
use_sliding_window=False,
|
||||||
|
sliding_window=4096,
|
||||||
|
max_window_layers=80,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
vision_config=None,
|
||||||
|
rope_scaling=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if isinstance(vision_config, dict):
|
||||||
|
self.vision_config = Qwen2VLVisionConfig(**vision_config)
|
||||||
|
elif vision_config is None:
|
||||||
|
self.vision_config = Qwen2VLVisionConfig()
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.use_sliding_window = use_sliding_window
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self.max_window_layers = max_window_layers
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
|
||||||
|
# NOTE: the following section from original transformers config
|
||||||
|
# for Qwen2-VL is commented out to address rope config loading issue
|
||||||
|
#
|
||||||
|
# if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||||
|
# if self.rope_scaling["type"] == "mrope":
|
||||||
|
# self.rope_scaling["type"] = "default"
|
||||||
|
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||||
|
# rope_config_validation(self)
|
||||||
|
|
||||||
|
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||||
@@ -530,3 +530,17 @@ register_conv_template(
|
|||||||
stop_str=["<|im_end|>", "<|action_end|>"],
|
stop_str=["<|im_end|>", "<|action_end|>"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="qwen2-vl",
|
||||||
|
system_message="You are a helpful assistant.",
|
||||||
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
|
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||||
|
sep="<|im_end|>\n",
|
||||||
|
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
||||||
|
stop_str=["<|im_end|>"],
|
||||||
|
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -33,12 +33,13 @@ from transformers import (
|
|||||||
try:
|
try:
|
||||||
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||||
|
|
||||||
from sglang.srt.configs import ExaoneConfig
|
from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig
|
||||||
|
|
||||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||||
ChatGLMConfig.model_type: ChatGLMConfig,
|
ChatGLMConfig.model_type: ChatGLMConfig,
|
||||||
DbrxConfig.model_type: DbrxConfig,
|
DbrxConfig.model_type: DbrxConfig,
|
||||||
ExaoneConfig.model_type: ExaoneConfig,
|
ExaoneConfig.model_type: ExaoneConfig,
|
||||||
|
Qwen2VLConfig.model_type: Qwen2VLConfig,
|
||||||
}
|
}
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# We want this file to run without vllm dependency
|
# We want this file to run without vllm dependency
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ def _fwd_kernel(
|
|||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
|
IS_CAUSAL: tl.constexpr,
|
||||||
Lk: tl.constexpr,
|
Lk: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
@@ -78,7 +79,9 @@ def _fwd_kernel(
|
|||||||
mask_d = offs_d < Lk
|
mask_d = offs_d < Lk
|
||||||
|
|
||||||
q = tl.load(
|
q = tl.load(
|
||||||
Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
|
Q + off_q,
|
||||||
|
mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
|
||||||
|
other=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
k_ptrs = K + off_k
|
k_ptrs = K + off_k
|
||||||
@@ -91,7 +94,12 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||||
|
|
||||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
end_n = (
|
||||||
|
cur_batch_seq_len
|
||||||
|
if not IS_CAUSAL
|
||||||
|
else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
|
||||||
|
)
|
||||||
|
for start_n in range(0, block_mask * end_n, BLOCK_N):
|
||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||||
# -- compute qk ----
|
# -- compute qk ----
|
||||||
k = tl.load(
|
k = tl.load(
|
||||||
@@ -104,7 +112,18 @@ def _fwd_kernel(
|
|||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q, k)
|
||||||
qk *= sm_scale
|
qk *= sm_scale
|
||||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
|
||||||
|
if IS_CAUSAL:
|
||||||
|
qk += tl.where(
|
||||||
|
(start_n + offs_n[None, :] < cur_batch_seq_len)
|
||||||
|
& (offs_m[:, None] >= (start_n + offs_n[None, :])),
|
||||||
|
0,
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qk += tl.where(
|
||||||
|
(start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
|
||||||
|
)
|
||||||
|
|
||||||
# -- compute m_ij, p, l_ij
|
# -- compute m_ij, p, l_ij
|
||||||
m_ij = tl.max(qk, 1)
|
m_ij = tl.max(qk, 1)
|
||||||
@@ -146,7 +165,9 @@ def _fwd_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
def context_attention_fwd(
|
||||||
|
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
||||||
|
):
|
||||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||||
BLOCK = 128
|
BLOCK = 128
|
||||||
else:
|
else:
|
||||||
@@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
|||||||
BLOCK_M=BLOCK,
|
BLOCK_M=BLOCK,
|
||||||
BLOCK_DMODEL=triton.next_power_of_2(Lk),
|
BLOCK_DMODEL=triton.next_power_of_2(Lk),
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
|
IS_CAUSAL=is_causal,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
Lk=Lk,
|
Lk=Lk,
|
||||||
|
|||||||
145
python/sglang/srt/layers/rotary_embedding.py
Normal file
145
python/sglang/srt/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""MRotaryEmbedding"""
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class MRotaryEmbedding:
|
||||||
|
"""Rotary Embedding with Multimodal Sections."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_input_positions(
|
||||||
|
input_tokens: List[int],
|
||||||
|
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
|
video_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
|
image_token_id: int,
|
||||||
|
video_token_id: int,
|
||||||
|
vision_start_token_id: int,
|
||||||
|
vision_end_token_id: int,
|
||||||
|
spatial_merge_size: int,
|
||||||
|
context_len: int = 0,
|
||||||
|
extend_prefix_len: int = 0,
|
||||||
|
) -> Tuple[List[List[int]], int]:
|
||||||
|
"""Get mrope input positions and delta value."""
|
||||||
|
|
||||||
|
if isinstance(image_grid_thw, torch.Tensor):
|
||||||
|
image_grid_thw = image_grid_thw.tolist()
|
||||||
|
if isinstance(video_grid_thw, torch.Tensor):
|
||||||
|
video_grid_thw = video_grid_thw.tolist()
|
||||||
|
|
||||||
|
input_tokens_tensor = torch.tensor(input_tokens)
|
||||||
|
vision_start_indices = torch.argwhere(
|
||||||
|
input_tokens_tensor == vision_start_token_id
|
||||||
|
).squeeze(1)
|
||||||
|
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||||
|
image_nums = (vision_tokens == image_token_id).sum()
|
||||||
|
video_nums = (vision_tokens == video_token_id).sum()
|
||||||
|
llm_pos_ids_list: list = []
|
||||||
|
|
||||||
|
st = 0
|
||||||
|
remain_images, remain_videos = image_nums, video_nums
|
||||||
|
|
||||||
|
image_index, video_index = 0, 0
|
||||||
|
for _ in range(image_nums + video_nums):
|
||||||
|
if image_token_id in input_tokens and remain_images > 0:
|
||||||
|
ed_image = input_tokens.index(image_token_id, st)
|
||||||
|
else:
|
||||||
|
ed_image = len(input_tokens) + 1
|
||||||
|
if video_token_id in input_tokens and remain_videos > 0:
|
||||||
|
ed_video = input_tokens.index(video_token_id, st)
|
||||||
|
else:
|
||||||
|
ed_video = len(input_tokens) + 1
|
||||||
|
if ed_image < ed_video:
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[image_index][0],
|
||||||
|
image_grid_thw[image_index][1],
|
||||||
|
image_grid_thw[image_index][2],
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
remain_images -= 1
|
||||||
|
ed = ed_image
|
||||||
|
else:
|
||||||
|
t, h, w = (
|
||||||
|
video_grid_thw[video_index][0],
|
||||||
|
video_grid_thw[video_index][1],
|
||||||
|
video_grid_thw[video_index][2],
|
||||||
|
)
|
||||||
|
video_index += 1
|
||||||
|
remain_videos -= 1
|
||||||
|
ed = ed_video
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t,
|
||||||
|
h // spatial_merge_size,
|
||||||
|
w // spatial_merge_size,
|
||||||
|
)
|
||||||
|
text_len = ed - st
|
||||||
|
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
t_index = (
|
||||||
|
torch.arange(llm_grid_t)
|
||||||
|
.view(-1, 1)
|
||||||
|
.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(llm_grid_t, -1, llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(llm_grid_t, llm_grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||||
|
)
|
||||||
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
|
if st < len(input_tokens):
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
text_len = len(input_tokens) - st
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
|
llm_positions = llm_positions[:, context_len:]
|
||||||
|
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||||
|
llm_positions += extend_prefix_len
|
||||||
|
|
||||||
|
return llm_positions.tolist(), mrope_position_delta
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_next_input_positions(
|
||||||
|
mrope_position_delta: int,
|
||||||
|
context_len: int,
|
||||||
|
seq_len: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
return [
|
||||||
|
list(
|
||||||
|
range(
|
||||||
|
context_len + mrope_position_delta, seq_len + mrope_position_delta
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
@@ -177,10 +177,127 @@ class LlavaImageProcessor(BaseImageProcessor):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||||
|
def __init__(self, hf_config, server_args, _image_processor):
|
||||||
|
self.hf_config = hf_config
|
||||||
|
self._image_processor = _image_processor
|
||||||
|
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||||
|
initializer=init_global_processor,
|
||||||
|
mp_context=mp.get_context("fork"),
|
||||||
|
initargs=(server_args,),
|
||||||
|
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_single_image_task(
|
||||||
|
image_data: Union[str, bytes],
|
||||||
|
image_processor=None,
|
||||||
|
):
|
||||||
|
image_processor = image_processor or global_processor.image_processor
|
||||||
|
|
||||||
|
try:
|
||||||
|
image, image_size = load_image(image_data)
|
||||||
|
if image_size is not None:
|
||||||
|
# It is a video with multiple images
|
||||||
|
image_hash = hash(image_data)
|
||||||
|
process_result = image_processor(image)
|
||||||
|
pixel_values, image_grid_thws = (
|
||||||
|
process_result["pixel_values"],
|
||||||
|
process_result["image_grid_thw"][0],
|
||||||
|
)
|
||||||
|
for _ in range(len(pixel_values)):
|
||||||
|
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||||
|
pixel_values = np.stack(pixel_values, axis=0)
|
||||||
|
image_grid_thws = np.stack(image_grid_thws, axis=0)
|
||||||
|
return pixel_values, image_hash, image_size, image_grid_thws
|
||||||
|
else:
|
||||||
|
# It is an image
|
||||||
|
image_hash = hash(image_data)
|
||||||
|
process_result = image_processor(image)
|
||||||
|
pixel_values, image_grid_thws = (
|
||||||
|
process_result["pixel_values"],
|
||||||
|
process_result["image_grid_thw"][0],
|
||||||
|
)
|
||||||
|
if isinstance(pixel_values, np.ndarray):
|
||||||
|
pixel_values = pixel_values.astype(np.float16)
|
||||||
|
|
||||||
|
return pixel_values, image_hash, image.size, image_grid_thws
|
||||||
|
except Exception:
|
||||||
|
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||||
|
|
||||||
|
async def _process_single_image(self, image_data: Union[bytes, str]):
|
||||||
|
if self.executor is not None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return await loop.run_in_executor(
|
||||||
|
self.executor,
|
||||||
|
Qwen2VLImageProcessor._process_single_image_task,
|
||||||
|
image_data,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._process_single_image_task(image_data)
|
||||||
|
|
||||||
|
async def process_images_async(
|
||||||
|
self, image_data: List[Union[str, bytes]], request_obj
|
||||||
|
):
|
||||||
|
if not image_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(image_data, list) and len(image_data) > 0:
|
||||||
|
# Multiple images
|
||||||
|
if len(image_data) > 1:
|
||||||
|
pixel_values, image_hashes, image_sizes, image_grid_thws = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
res = []
|
||||||
|
for img_data in image_data:
|
||||||
|
res.append(self._process_single_image(img_data))
|
||||||
|
res = await asyncio.gather(*res)
|
||||||
|
for pixel_v, image_h, image_s, image_thw in res:
|
||||||
|
pixel_values.append(pixel_v)
|
||||||
|
image_hashes.append(image_h)
|
||||||
|
image_sizes.append(image_s)
|
||||||
|
image_grid_thws.append(image_thw)
|
||||||
|
|
||||||
|
if isinstance(pixel_values[0], np.ndarray):
|
||||||
|
pixel_values = np.concatenate(pixel_values, axis=0)
|
||||||
|
else:
|
||||||
|
# A single image
|
||||||
|
pixel_values, image_hash, image_size, image_grid_thw = (
|
||||||
|
await self._process_single_image(image_data[0])
|
||||||
|
)
|
||||||
|
image_hashes = [image_hash]
|
||||||
|
image_sizes = [image_size]
|
||||||
|
image_grid_thws = [image_grid_thw]
|
||||||
|
elif isinstance(image_data, str):
|
||||||
|
# A single image
|
||||||
|
pixel_values, image_hash, image_size, image_grid_thw = (
|
||||||
|
await self._process_single_image(image_data)
|
||||||
|
)
|
||||||
|
image_hashes = [image_hash]
|
||||||
|
image_sizes = [image_size]
|
||||||
|
image_grid_thws = [image_grid_thw]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid image data: {image_data}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"image_hashes": image_hashes,
|
||||||
|
"image_sizes": image_sizes,
|
||||||
|
"modalities": request_obj.modalities,
|
||||||
|
"image_grid_thws": image_grid_thws,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_image_processor(
|
def get_image_processor(
|
||||||
hf_config, server_args: ServerArgs, _image_processor
|
hf_config, server_args: ServerArgs, _image_processor
|
||||||
) -> BaseImageProcessor:
|
) -> BaseImageProcessor:
|
||||||
return LlavaImageProcessor(hf_config, server_args, _image_processor)
|
if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||||
|
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
|
||||||
|
else:
|
||||||
|
return LlavaImageProcessor(hf_config, server_args, _image_processor)
|
||||||
|
|
||||||
|
|
||||||
def get_dummy_image_processor():
|
def get_dummy_image_processor():
|
||||||
|
|||||||
@@ -128,6 +128,8 @@ class ImageInputs:
|
|||||||
image_embeds: Optional[List[torch.Tensor]] = None
|
image_embeds: Optional[List[torch.Tensor]] = None
|
||||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||||
|
# QWen2-VL related
|
||||||
|
image_grid_thws: List[Tuple[int, int, int]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(obj, vocab_size):
|
def from_dict(obj, vocab_size):
|
||||||
@@ -135,6 +137,7 @@ class ImageInputs:
|
|||||||
ret = ImageInputs(
|
ret = ImageInputs(
|
||||||
pixel_values=obj["pixel_values"],
|
pixel_values=obj["pixel_values"],
|
||||||
image_hash=hash(tuple(obj["image_hashes"])),
|
image_hash=hash(tuple(obj["image_hashes"])),
|
||||||
|
image_grid_thws=obj.get("image_grid_thws"),
|
||||||
)
|
)
|
||||||
image_hash = ret.image_hash
|
image_hash = ret.image_hash
|
||||||
ret.pad_values = [
|
ret.pad_values = [
|
||||||
@@ -236,6 +239,9 @@ class Req:
|
|||||||
self.regex_fsm_state: int = 0
|
self.regex_fsm_state: int = 0
|
||||||
self.jump_forward_map: JumpForwardMap = None
|
self.jump_forward_map: JumpForwardMap = None
|
||||||
|
|
||||||
|
# For Qwen2-VL
|
||||||
|
self.mrope_position_delta = [] # use mutable object
|
||||||
|
|
||||||
# whether request reached finished condition
|
# whether request reached finished condition
|
||||||
def finished(self) -> bool:
|
def finished(self) -> bool:
|
||||||
return self.finished_reason is not None
|
return self.finished_reason is not None
|
||||||
@@ -854,6 +860,8 @@ class ScheduleBatch:
|
|||||||
global bid
|
global bid
|
||||||
bid += 1
|
bid += 1
|
||||||
|
|
||||||
|
mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
|
||||||
|
|
||||||
return ModelWorkerBatch(
|
return ModelWorkerBatch(
|
||||||
bid=bid,
|
bid=bid,
|
||||||
forward_mode=self.forward_mode,
|
forward_mode=self.forward_mode,
|
||||||
@@ -869,6 +877,7 @@ class ScheduleBatch:
|
|||||||
image_inputs=image_inputs,
|
image_inputs=image_inputs,
|
||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
sampling_info=self.sampling_info,
|
sampling_info=self.sampling_info,
|
||||||
|
mrope_positions_delta=mrope_positions_delta,
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
@@ -920,6 +929,9 @@ class ModelWorkerBatch:
|
|||||||
# Sampling info
|
# Sampling info
|
||||||
sampling_info: SamplingBatchInfo
|
sampling_info: SamplingBatchInfo
|
||||||
|
|
||||||
|
# For Qwen2-VL
|
||||||
|
mrope_positions_delta: List[List[int]]
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return ModelWorkerBatch(
|
return ModelWorkerBatch(
|
||||||
bid=self.bid,
|
bid=self.bid,
|
||||||
@@ -936,4 +948,5 @@ class ModelWorkerBatch:
|
|||||||
image_inputs=self.image_inputs,
|
image_inputs=self.image_inputs,
|
||||||
lora_paths=self.lora_paths,
|
lora_paths=self.lora_paths,
|
||||||
sampling_info=self.sampling_info.copy(),
|
sampling_info=self.sampling_info.copy(),
|
||||||
|
mrope_positions_delta=self.mrope_positions_delta,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||||
@@ -112,14 +114,88 @@ class ForwardBatch:
|
|||||||
token_to_kv_pool: BaseTokenToKVPool = None
|
token_to_kv_pool: BaseTokenToKVPool = None
|
||||||
attn_backend: AttentionBackend = None
|
attn_backend: AttentionBackend = None
|
||||||
|
|
||||||
|
# For Qwen2-VL
|
||||||
|
mrope_positions: torch.Tensor = None
|
||||||
|
|
||||||
|
def compute_mrope_positions(
|
||||||
|
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||||
|
):
|
||||||
|
device = model_runner.device
|
||||||
|
hf_config = model_runner.model_config.hf_config
|
||||||
|
mrope_positions_list = [None] * self.seq_lens.shape[0]
|
||||||
|
if self.forward_mode.is_decode():
|
||||||
|
for i, _ in enumerate(mrope_positions_list):
|
||||||
|
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
|
||||||
|
batch.mrope_positions_delta[i][0],
|
||||||
|
int(self.seq_lens[i]) - 1,
|
||||||
|
int(self.seq_lens[i]),
|
||||||
|
)
|
||||||
|
elif self.forward_mode.is_extend():
|
||||||
|
for i, image_inputs in enumerate(batch.image_inputs):
|
||||||
|
if image_inputs is None:
|
||||||
|
# text only
|
||||||
|
mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3
|
||||||
|
mrope_position_delta = 0
|
||||||
|
else:
|
||||||
|
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||||
|
self.extend_start_loc[i],
|
||||||
|
self.extend_seq_lens[i],
|
||||||
|
self.extend_prefix_lens[i],
|
||||||
|
)
|
||||||
|
mrope_positions, mrope_position_delta = (
|
||||||
|
MRotaryEmbedding.get_input_positions(
|
||||||
|
input_tokens=self.input_ids[
|
||||||
|
extend_start_loc : extend_start_loc + extend_seq_len
|
||||||
|
].tolist(),
|
||||||
|
image_grid_thw=image_inputs.image_grid_thws,
|
||||||
|
video_grid_thw=None,
|
||||||
|
image_token_id=hf_config.image_token_id,
|
||||||
|
video_token_id=hf_config.video_token_id,
|
||||||
|
vision_start_token_id=hf_config.vision_start_token_id,
|
||||||
|
vision_end_token_id=hf_config.vision_end_token_id,
|
||||||
|
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||||
|
context_len=0,
|
||||||
|
extend_prefix_len=extend_prefix_len.item(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mrope_positions_list[i] = mrope_positions
|
||||||
|
batch.mrope_positions_delta[i].append(mrope_position_delta)
|
||||||
|
|
||||||
|
self.mrope_positions = torch.tensor(
|
||||||
|
np.concatenate(
|
||||||
|
[np.array(pos) for pos in mrope_positions_list],
|
||||||
|
axis=1,
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||||
|
|
||||||
|
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
|
||||||
|
device = model_runner.device
|
||||||
|
if self.forward_mode.is_decode():
|
||||||
|
self.positions = (self.seq_lens - 1).to(torch.int64)
|
||||||
|
else:
|
||||||
|
self.positions = torch.tensor(
|
||||||
|
np.concatenate(
|
||||||
|
[
|
||||||
|
np.arange(prefix_len, prefix_len + extend_len)
|
||||||
|
for prefix_len, extend_len in zip(
|
||||||
|
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||||
|
)
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
).to(torch.int64)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
batch: ModelWorkerBatch,
|
batch: ModelWorkerBatch,
|
||||||
model_runner: ModelRunner,
|
model_runner: ModelRunner,
|
||||||
):
|
):
|
||||||
device = model_runner.device
|
|
||||||
|
|
||||||
|
device = model_runner.device
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
batch_size=len(batch.seq_lens),
|
batch_size=len(batch.seq_lens),
|
||||||
@@ -156,6 +232,13 @@ class ForwardBatch:
|
|||||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||||
|
|
||||||
|
# Init position information
|
||||||
|
is_mrope = model_runner.model_is_mrope
|
||||||
|
if is_mrope:
|
||||||
|
ret.compute_mrope_positions(model_runner, batch)
|
||||||
|
else:
|
||||||
|
ret.compute_positions(model_runner, batch)
|
||||||
|
|
||||||
# Init attention information
|
# Init attention information
|
||||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||||
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||||
|
|||||||
@@ -125,6 +125,11 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
server_args.chunked_prefill_size = None
|
server_args.chunked_prefill_size = None
|
||||||
server_args.mem_fraction_static *= 0.95
|
server_args.mem_fraction_static *= 0.95
|
||||||
|
# TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
|
||||||
|
if self.model_config.hf_config.architectures == [
|
||||||
|
"Qwen2VLForConditionalGeneration"
|
||||||
|
]:
|
||||||
|
server_args.disable_cuda_graph = True
|
||||||
|
|
||||||
# Global vars
|
# Global vars
|
||||||
if server_args.show_time_cost:
|
if server_args.show_time_cost:
|
||||||
@@ -622,6 +627,15 @@ class ModelRunner:
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_is_mrope(self) -> bool:
|
||||||
|
"""Detect if the model has "mrope" rope_scaling type.
|
||||||
|
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
||||||
|
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
||||||
|
if rope_scaling is None:
|
||||||
|
return False
|
||||||
|
return rope_scaling.get("type", None) == "mrope"
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def import_model_classes():
|
def import_model_classes():
|
||||||
|
|||||||
720
python/sglang/srt/models/qwen2_vl.py
Normal file
720
python/sglang/srt/models/qwen2_vl.py
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
||||||
|
# Copyright 2024 The Qwen team.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||||
|
from functools import lru_cache, partial
|
||||||
|
from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
|
from vllm.distributed import parallel_state
|
||||||
|
from vllm.distributed import utils as dist_utils
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.activation import QuickGELU
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||||
|
|
||||||
|
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||||
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
|
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||||
|
context_attention_fwd,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.models.qwen2 import Qwen2Model
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# === Vision Inputs === #
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLImageInputs(TypedDict):
|
||||||
|
pixel_values: torch.Tensor
|
||||||
|
"""Shape:
|
||||||
|
`(num_patches, num_channels * patch_size * patch_size)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_grid_thw: torch.Tensor
|
||||||
|
"""Shape: `(num_images, 3)`
|
||||||
|
|
||||||
|
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLVideoInputs(TypedDict):
|
||||||
|
pixel_values_videos: torch.Tensor
|
||||||
|
"""Shape:
|
||||||
|
`(num_patches,
|
||||||
|
num_channels * temporal_patch_size * patch_size * patch_size)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
video_grid_thw: torch.Tensor
|
||||||
|
"""Shape: `(num_videos, 3)`
|
||||||
|
|
||||||
|
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# === Vision Encoder === #
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_features: int = None,
|
||||||
|
act_layer: Type[nn.Module] = QuickGELU,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = ColumnParallelLinear(
|
||||||
|
in_features, hidden_features, quant_config=quant_config
|
||||||
|
)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = RowParallelLinear(
|
||||||
|
hidden_features, in_features, quant_config=quant_config
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x_parallel, _ = self.fc1(x)
|
||||||
|
x_parallel = self.act(x_parallel)
|
||||||
|
x, _ = self.fc2(x_parallel)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
||||||
|
if not interleaved:
|
||||||
|
x1, x2 = x.chunk(2, dim=-1)
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
else:
|
||||||
|
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||||
|
return rearrange(
|
||||||
|
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb_torch(
|
||||||
|
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: (batch_size, seqlen, nheads, headdim)
|
||||||
|
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||||
|
"""
|
||||||
|
ro_dim = cos.shape[-1] * 2
|
||||||
|
assert ro_dim <= x.shape[-1]
|
||||||
|
cos = repeat(
|
||||||
|
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||||
|
)
|
||||||
|
sin = repeat(
|
||||||
|
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||||
|
)
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
||||||
|
x[..., ro_dim:],
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||||
|
t_ = t.float()
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: Optional[int] = None,
|
||||||
|
num_heads: Optional[int] = None,
|
||||||
|
projection_size: Optional[int] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
# Per attention head and per partition values.
|
||||||
|
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||||
|
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||||
|
projection_size, num_heads
|
||||||
|
)
|
||||||
|
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||||
|
num_heads, world_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.qkv = ColumnParallelLinear(
|
||||||
|
input_size=embed_dim,
|
||||||
|
output_size=3 * projection_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.proj = RowParallelLinear(
|
||||||
|
input_size=projection_size, output_size=embed_dim, quant_config=quant_config
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
rotary_pos_emb: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||||
|
x, _ = self.qkv(x)
|
||||||
|
|
||||||
|
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
|
||||||
|
new_x_shape = x.size()[:-1] + (
|
||||||
|
self.num_attention_heads_per_partition,
|
||||||
|
3 * self.hidden_size_per_attention_head,
|
||||||
|
)
|
||||||
|
x = x.view(*new_x_shape)
|
||||||
|
|
||||||
|
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
|
||||||
|
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
|
||||||
|
batch_size = q.shape[1]
|
||||||
|
|
||||||
|
q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
|
||||||
|
if rotary_pos_emb is not None:
|
||||||
|
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
||||||
|
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
||||||
|
|
||||||
|
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
|
max_seqlen = (seq_lens).max().item()
|
||||||
|
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
||||||
|
|
||||||
|
output = torch.empty_like(q)
|
||||||
|
context_attention_fwd(
|
||||||
|
q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
|
||||||
|
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
|
||||||
|
|
||||||
|
output, _ = self.proj(context_layer)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float,
|
||||||
|
act_layer: Type[nn.Module] = QuickGELU,
|
||||||
|
norm_layer: Type[nn.Module] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
|
||||||
|
self.attn = Qwen2VisionAttention(
|
||||||
|
embed_dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
projection_size=dim,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.mlp = Qwen2VisionMLP(
|
||||||
|
dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
x = x + self.attn(
|
||||||
|
self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||||
|
)
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionPatchEmbed(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 14,
|
||||||
|
temporal_patch_size: int = 2,
|
||||||
|
in_chans: int = 3,
|
||||||
|
embed_dim: int = 1152,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
||||||
|
self.proj = nn.Conv3d(
|
||||||
|
in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
L, C = x.shape
|
||||||
|
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||||
|
x = self.proj(x).view(L, self.embed_dim)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionPatchMerger(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
context_dim: int,
|
||||||
|
norm_layer: Type[nn.Module] = None,
|
||||||
|
spatial_merge_size: int = 2,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
self.ln_q = norm_layer(context_dim)
|
||||||
|
self.mlp = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
),
|
||||||
|
nn.GELU(),
|
||||||
|
RowParallelLinear(
|
||||||
|
self.hidden_size, d_model, bias=True, quant_config=quant_config
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.ln_q(x)
|
||||||
|
x = x.view(-1, self.hidden_size)
|
||||||
|
|
||||||
|
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
|
||||||
|
x_parallel, _ = mlp_fc1(x)
|
||||||
|
x_parallel = mlp_act(x_parallel)
|
||||||
|
out, _ = mlp_fc2(x_parallel)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionRotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
self._seq_len_cached = 0
|
||||||
|
self._freqs_cached = None
|
||||||
|
|
||||||
|
def update_freqs_cache(self, seqlen: int) -> None:
|
||||||
|
if seqlen > self._seq_len_cached:
|
||||||
|
seqlen *= 2
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
self.inv_freq = 1.0 / (
|
||||||
|
self.theta
|
||||||
|
** (
|
||||||
|
torch.arange(
|
||||||
|
0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
|
||||||
|
)
|
||||||
|
/ self.dim
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seq = torch.arange(
|
||||||
|
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
||||||
|
)
|
||||||
|
freqs = torch.outer(seq, self.inv_freq)
|
||||||
|
self._freqs_cached = freqs
|
||||||
|
|
||||||
|
def forward(self, seqlen: int) -> torch.Tensor:
|
||||||
|
self.update_freqs_cache(seqlen)
|
||||||
|
return self._freqs_cached[:seqlen]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VisionTransformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_config: Qwen2VLVisionConfig,
|
||||||
|
norm_eps: float = 1e-6,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
patch_size: int = vision_config.patch_size
|
||||||
|
temporal_patch_size: int = vision_config.temporal_patch_size
|
||||||
|
spatial_merge_size: int = vision_config.spatial_merge_size
|
||||||
|
in_chans: int = vision_config.in_chans
|
||||||
|
hidden_size: int = vision_config.hidden_size
|
||||||
|
embed_dim: int = vision_config.embed_dim
|
||||||
|
depth: int = vision_config.depth
|
||||||
|
num_heads: int = vision_config.num_heads
|
||||||
|
mlp_ratio: float = vision_config.mlp_ratio
|
||||||
|
|
||||||
|
self.spatial_merge_size = spatial_merge_size
|
||||||
|
|
||||||
|
self.patch_embed = Qwen2VisionPatchEmbed(
|
||||||
|
patch_size=patch_size,
|
||||||
|
temporal_patch_size=temporal_patch_size,
|
||||||
|
in_chans=in_chans,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Qwen2VisionBlock(
|
||||||
|
dim=embed_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
for _ in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.merger = Qwen2VisionPatchMerger(
|
||||||
|
d_model=hidden_size,
|
||||||
|
context_dim=embed_dim,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> torch.dtype:
|
||||||
|
return self.blocks[0].mlp.fc2.weight.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.blocks[0].mlp.fc2.weight.device
|
||||||
|
|
||||||
|
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||||
|
pos_ids = []
|
||||||
|
for t, h, w in grid_thw:
|
||||||
|
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||||
|
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||||
|
hpos_ids = (
|
||||||
|
hpos_ids.reshape(
|
||||||
|
h // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
w // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
wpos_ids = (
|
||||||
|
wpos_ids.reshape(
|
||||||
|
h // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
w // self.spatial_merge_size,
|
||||||
|
self.spatial_merge_size,
|
||||||
|
)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||||
|
pos_ids = torch.cat(pos_ids, dim=0)
|
||||||
|
max_grid_size = grid_thw[:, 1:].max()
|
||||||
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||||
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||||
|
return rotary_pos_emb
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
grid_thw: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# patchify
|
||||||
|
x = x.to(device=self.device, dtype=self.dtype)
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
# compute position embedding
|
||||||
|
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||||
|
|
||||||
|
# compute cu_seqlens
|
||||||
|
cu_seqlens = torch.repeat_interleave(
|
||||||
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||||
|
|
||||||
|
# transformers
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
for blk in self.blocks:
|
||||||
|
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
||||||
|
|
||||||
|
# adapter
|
||||||
|
x = self.merger(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
cached_get_processor = lru_cache(get_processor)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||||
|
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
||||||
|
processor = cached_get_processor(self.config._name_or_path)
|
||||||
|
grid_t, grid_h, grid_w = image_grid_thw
|
||||||
|
num_image_tokens = (
|
||||||
|
grid_t
|
||||||
|
* grid_h
|
||||||
|
* grid_w
|
||||||
|
// processor.image_processor.merge_size
|
||||||
|
// processor.image_processor.merge_size
|
||||||
|
)
|
||||||
|
return num_image_tokens
|
||||||
|
|
||||||
|
# Use grid_t * grid_w * grid_h to pad tokens for each image
|
||||||
|
# and replaced padding by unique image hash
|
||||||
|
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||||
|
image_grid_thws = image_inputs.image_grid_thws
|
||||||
|
pad_values = image_inputs.pad_values
|
||||||
|
|
||||||
|
image_indices = [
|
||||||
|
idx
|
||||||
|
for idx, token in enumerate(input_ids)
|
||||||
|
if token == self.config.image_token_id
|
||||||
|
]
|
||||||
|
image_inputs.image_offsets = []
|
||||||
|
|
||||||
|
input_ids_with_image = []
|
||||||
|
for image_cnt, _ in enumerate(image_grid_thws):
|
||||||
|
num_image_tokens = self.calculate_num_image_tokens(
|
||||||
|
image_grid_thws[image_cnt]
|
||||||
|
)
|
||||||
|
if image_cnt == 0:
|
||||||
|
non_image_tokens = input_ids[: image_indices[image_cnt]]
|
||||||
|
else:
|
||||||
|
non_image_tokens = input_ids[
|
||||||
|
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
|
||||||
|
]
|
||||||
|
input_ids_with_image.extend(non_image_tokens)
|
||||||
|
image_inputs.image_offsets.append(len(input_ids_with_image))
|
||||||
|
pad_ids = pad_values * (
|
||||||
|
(num_image_tokens + len(pad_values)) // len(pad_values)
|
||||||
|
)
|
||||||
|
input_ids_with_image.extend(pad_ids[:num_image_tokens])
|
||||||
|
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
|
||||||
|
|
||||||
|
return input_ids_with_image
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen2VLConfig,
|
||||||
|
multimodal_config: MultiModalConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
|
self.visual = Qwen2VisionTransformer(
|
||||||
|
config.vision_config,
|
||||||
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
|
# NOTE: Qwen2-VL vision encoder does not support any
|
||||||
|
# quantization method now.
|
||||||
|
quant_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = Qwen2Model(config, quant_config)
|
||||||
|
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||||
|
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
|
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||||
|
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
|
||||||
|
video_embeds = self.visual(
|
||||||
|
pixel_values_videos, grid_thw=video_input["video_grid_thw"]
|
||||||
|
)
|
||||||
|
return video_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
|
"""Run forward pass for Qwen2-VL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||||
|
batch.
|
||||||
|
positions: Flattened (concatenated) position ids corresponding to a
|
||||||
|
batch.
|
||||||
|
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
||||||
|
opensource models), the shape will be `(3, seq_len)`,
|
||||||
|
otherwise it will be `(seq_len,).
|
||||||
|
(Use input_metadata.mrope_positions to replace it)
|
||||||
|
pixel_values: Pixel values to be fed to a model.
|
||||||
|
`None` if no images are passed.
|
||||||
|
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
||||||
|
`None` if no images are passed.
|
||||||
|
"""
|
||||||
|
image_inputs = None
|
||||||
|
if forward_batch.image_inputs is not None:
|
||||||
|
image_inputs = [
|
||||||
|
img for img in forward_batch.image_inputs if img is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
positions = forward_batch.mrope_positions
|
||||||
|
if image_inputs is None or len(image_inputs) == 0:
|
||||||
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
|
else:
|
||||||
|
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||||
|
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||||
|
"multimodal section rotary embedding requires "
|
||||||
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
|
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||||
|
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||||
|
for i, image in enumerate(forward_batch.image_inputs):
|
||||||
|
if image == None:
|
||||||
|
continue
|
||||||
|
start_idx = extend_start_loc_cpu[i]
|
||||||
|
prefix_len = prefix_lens_cpu[i]
|
||||||
|
|
||||||
|
pixel_values = torch.tensor(image.pixel_values, device="cuda")
|
||||||
|
image_grid_thws = torch.tensor(
|
||||||
|
np.array(image.image_grid_thws), device="cuda"
|
||||||
|
)
|
||||||
|
image_offsets = image.image_offsets
|
||||||
|
image_input = Qwen2VLImageInputs(
|
||||||
|
pixel_values=pixel_values, image_grid_thw=image_grid_thws
|
||||||
|
)
|
||||||
|
image_embeds = self._process_image_input(image_input)
|
||||||
|
|
||||||
|
image_embeds_offset = 0
|
||||||
|
for idx, image_offset in enumerate(image_offsets):
|
||||||
|
if image_offset < prefix_len:
|
||||||
|
continue
|
||||||
|
num_image_tokens = self.calculate_num_image_tokens(
|
||||||
|
image_grid_thws[idx]
|
||||||
|
)
|
||||||
|
left_idx = start_idx + (image_offset - prefix_len)
|
||||||
|
right_idx = (
|
||||||
|
start_idx + (image_offset - prefix_len) + num_image_tokens
|
||||||
|
)
|
||||||
|
inputs_embeds[left_idx:right_idx] = image_embeds[
|
||||||
|
image_embeds_offset : image_embeds_offset + num_image_tokens
|
||||||
|
]
|
||||||
|
image_embeds_offset += num_image_tokens
|
||||||
|
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
input_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
return self.logits_processor(
|
||||||
|
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if "visual" in name and "qkv.weight" in name:
|
||||||
|
visual_num_heads = self.config.vision_config.num_heads
|
||||||
|
visual_embed_dim = self.config.vision_config.embed_dim
|
||||||
|
head_size = visual_embed_dim // visual_num_heads
|
||||||
|
loaded_weight = loaded_weight.view(
|
||||||
|
3, visual_num_heads, head_size, visual_embed_dim
|
||||||
|
)
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
||||||
|
elif "visual" in name and "qkv.bias" in name:
|
||||||
|
visual_num_heads = self.config.vision_config.num_heads
|
||||||
|
visual_embed_dim = self.config.vision_config.embed_dim
|
||||||
|
head_size = visual_embed_dim // visual_num_heads
|
||||||
|
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(-1)
|
||||||
|
try:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
except KeyError:
|
||||||
|
print(params_dict.keys())
|
||||||
|
raise
|
||||||
|
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = Qwen2VLForConditionalGeneration
|
||||||
@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
|
|||||||
or "LlavaQwenForCausalLM" in model_architectures
|
or "LlavaQwenForCausalLM" in model_architectures
|
||||||
or "LlavaMistralForCausalLM" in model_architectures
|
or "LlavaMistralForCausalLM" in model_architectures
|
||||||
or "LlavaVidForCausalLM" in model_architectures
|
or "LlavaVidForCausalLM" in model_architectures
|
||||||
|
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -344,5 +344,24 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
|||||||
list(executor.map(self.run_decode_with_image, image_ids))
|
list(executor.map(self.run_decode_with_image, image_ids))
|
||||||
|
|
||||||
|
|
||||||
|
class TestQWen2VLServer(TestOpenAIVisionServer):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "Qwen/Qwen2-VL-7B-Instruct"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
api_key=cls.api_key,
|
||||||
|
other_args=[
|
||||||
|
"--chat-template",
|
||||||
|
"qwen2-vl",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user