Model: Support Qwen 2.5 vl (#3258)
This commit is contained in:
722
python/sglang/srt/models/qwen2_5_vl.py
Normal file
722
python/sglang/srt/models/qwen2_5_vl.py
Normal file
@@ -0,0 +1,722 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
import logging
|
||||
from functools import lru_cache, partial
|
||||
from typing import Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import AutoModel, Qwen2VLConfig
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||
|
||||
from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qwen2_5_VLMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int = None,
|
||||
bias: bool = True,
|
||||
hidden_act="silu",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_proj = ColumnParallelLinear(
|
||||
in_features, hidden_features, bias=bias, quant_config=quant_config
|
||||
)
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
in_features, hidden_features, bias=bias, quant_config=quant_config
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
hidden_features, in_features, bias=bias, quant_config=quant_config
|
||||
)
|
||||
self.act = ACT2FN[hidden_act]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_parallel_gate, _ = self.gate_proj(x)
|
||||
x_parallel_gate = self.act(x_parallel_gate)
|
||||
x_parallel_up, _ = self.up_proj(x)
|
||||
x_parallel = x_parallel_gate * x_parallel_up
|
||||
x, _ = self.down_proj(x_parallel)
|
||||
return x
|
||||
|
||||
|
||||
class Qwen2_5_VisionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
num_heads: int,
|
||||
hidden_act="silu",
|
||||
norm_layer: Type[nn.Module] = None,
|
||||
attn_implementation: Optional[str] = "sdpa",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||
if attn_implementation == "sdpa":
|
||||
use_context_forward = False
|
||||
use_full_precision_softmax = False
|
||||
elif attn_implementation == "flash_attention_2":
|
||||
use_full_precision_softmax = False
|
||||
use_context_forward = True
|
||||
elif attn_implementation == "eager":
|
||||
use_full_precision_softmax = True
|
||||
use_context_forward = False
|
||||
|
||||
self.attn = VisionAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=False,
|
||||
use_context_forward=use_context_forward,
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
flatten_batch=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = Qwen2_5_VLMLP(
|
||||
dim, intermediate_dim, hidden_act=hidden_act, quant_config=quant_config
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm1(x)
|
||||
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
|
||||
attn = self.attn(
|
||||
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
||||
)
|
||||
attn = rearrange(attn, "b s ... -> s b ...")
|
||||
x = x + attn
|
||||
norm2 = self.norm2(x)
|
||||
mlp = self.mlp(norm2)
|
||||
x = x + mlp
|
||||
return x
|
||||
|
||||
|
||||
class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 1152,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
||||
self.proj = nn.Conv3d(
|
||||
in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
L, C = x.shape
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.embed_dim)
|
||||
return x
|
||||
|
||||
|
||||
class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
context_dim: int,
|
||||
spatial_merge_size: int = 2,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
|
||||
self.mlp = nn.ModuleList(
|
||||
[
|
||||
ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
nn.GELU(),
|
||||
RowParallelLinear(
|
||||
self.hidden_size, dim, bias=True, quant_config=quant_config
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.ln_q(x)
|
||||
x = x.view(-1, self.hidden_size)
|
||||
|
||||
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
|
||||
x_parallel, _ = mlp_fc1(x)
|
||||
x_parallel = mlp_act(x_parallel)
|
||||
out, _ = mlp_fc2(x_parallel)
|
||||
return out
|
||||
|
||||
|
||||
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._seq_len_cached = 0
|
||||
self._freqs_cached = None
|
||||
|
||||
def update_freqs_cache(self, seqlen: int) -> None:
|
||||
if seqlen > self._seq_len_cached:
|
||||
seqlen *= 2
|
||||
self._seq_len_cached = seqlen
|
||||
self.inv_freq = 1.0 / (
|
||||
self.theta
|
||||
** (
|
||||
torch.arange(
|
||||
0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
|
||||
)
|
||||
/ self.dim
|
||||
)
|
||||
)
|
||||
seq = torch.arange(
|
||||
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
||||
)
|
||||
freqs = torch.outer(seq, self.inv_freq)
|
||||
self._freqs_cached = freqs
|
||||
|
||||
def forward(self, seqlen: int) -> torch.Tensor:
|
||||
self.update_freqs_cache(seqlen)
|
||||
return self._freqs_cached[:seqlen]
|
||||
|
||||
|
||||
class Qwen2_5_VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Qwen2_5_VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
patch_size: int = vision_config.patch_size
|
||||
temporal_patch_size: int = vision_config.temporal_patch_size
|
||||
spatial_merge_size: int = vision_config.spatial_merge_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
|
||||
in_chans: int = vision_config.in_chans
|
||||
hidden_size: int = vision_config.hidden_size
|
||||
depth: int = vision_config.depth
|
||||
num_heads: int = vision_config.num_heads
|
||||
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
|
||||
self.window_size = vision_config.window_size
|
||||
self.patch_size = vision_config.patch_size
|
||||
mlp_hidden_size: int = vision_config.intermediate_size
|
||||
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=hidden_size,
|
||||
)
|
||||
|
||||
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
||||
head_dim = hidden_size // num_heads
|
||||
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen2_5_VisionBlock(
|
||||
dim=hidden_size,
|
||||
intermediate_dim=mlp_hidden_size,
|
||||
num_heads=num_heads,
|
||||
hidden_act=vision_config.hidden_act,
|
||||
norm_layer=norm_layer,
|
||||
attn_implementation="sdpa",
|
||||
quant_config=quant_config,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen2_5_VisionPatchMerger(
|
||||
dim=vision_config.out_hidden_size,
|
||||
context_dim=hidden_size,
|
||||
spatial_merge_size=spatial_merge_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (
|
||||
self.window_size // self.spatial_merge_size // self.patch_size
|
||||
)
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h, llm_grid_w = (
|
||||
grid_h // self.spatial_merge_size,
|
||||
grid_w // self.spatial_merge_size,
|
||||
)
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||
grid_t, llm_grid_h, llm_grid_w
|
||||
)
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
|
||||
index_padded = index_padded.reshape(
|
||||
grid_t,
|
||||
num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t,
|
||||
num_windows_h * num_windows_w,
|
||||
vit_merger_window_size,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
cu_seqlens_tmp = (
|
||||
seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||
)
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.blocks[0].mlp.gate_proj.weight.dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.blocks[0].mlp.gate_proj.weight.device
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
hpos_ids = (
|
||||
hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
.permute(0, 2, 1, 3)
|
||||
.flatten()
|
||||
)
|
||||
wpos_ids = (
|
||||
wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
.permute(0, 2, 1, 3)
|
||||
.flatten()
|
||||
)
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# patchify
|
||||
x = x.to(device=self.device, dtype=self.dtype)
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
|
||||
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
||||
cu_window_seqlens = torch.tensor(
|
||||
cu_window_seqlens,
|
||||
device=x.device,
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||
)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
|
||||
seq_len, _ = x.size()
|
||||
|
||||
x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
x = x[window_index, :, :]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
|
||||
)
|
||||
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
if layer_num in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
|
||||
|
||||
# adapter
|
||||
x = self.merger(x)
|
||||
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
x = x[reverse_indices, :]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2VLConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.visual = Qwen2_5_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
# NOTE: Qwen2-VL vision encoder does not support any
|
||||
# quantization method now.
|
||||
quant_config=None,
|
||||
)
|
||||
|
||||
self.model = Qwen2Model(config, quant_config)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
||||
processor = cached_get_processor(self.config._name_or_path)
|
||||
grid_t, grid_h, grid_w = image_grid_thw
|
||||
num_image_tokens = (
|
||||
grid_t
|
||||
* grid_h
|
||||
* grid_w
|
||||
// processor.image_processor.merge_size
|
||||
// processor.image_processor.merge_size
|
||||
)
|
||||
return num_image_tokens
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
new_input_ids = []
|
||||
last_idx = 0
|
||||
image_idx = -1
|
||||
image_inputs.image_offsets = []
|
||||
|
||||
# Get all special token IDs
|
||||
im_start_id = image_inputs.im_start_id
|
||||
im_end_id = image_inputs.im_end_id
|
||||
|
||||
# Find all start and end positions for both types
|
||||
start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
|
||||
end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
|
||||
|
||||
if len(start_indices) != len(end_indices):
|
||||
return input_ids
|
||||
# Process each region (both image and slice)
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
# Add non-image tokens before this region
|
||||
new_input_ids.extend(input_ids[last_idx : start_idx + 1])
|
||||
|
||||
is_image_start = input_ids[start_idx] == im_start_id
|
||||
|
||||
if is_image_start:
|
||||
image_inputs.image_offsets += [start_idx]
|
||||
image_idx += 1
|
||||
|
||||
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
|
||||
|
||||
# Generate pad_ids
|
||||
pad_values = [image_inputs.pad_values[image_idx]]
|
||||
|
||||
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
|
||||
pad_ids = pad_ids[:num_tokens]
|
||||
|
||||
# Add pad_ids
|
||||
new_input_ids.extend(pad_ids)
|
||||
|
||||
# Update last_idx to after end token
|
||||
last_idx = end_idx
|
||||
|
||||
# Add remaining tokens after last region
|
||||
new_input_ids.extend(input_ids[last_idx:])
|
||||
assert len(input_ids) == len(new_input_ids)
|
||||
return new_input_ids
|
||||
|
||||
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
|
||||
return image_embeds
|
||||
|
||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
|
||||
video_embeds = self.visual(
|
||||
pixel_values_videos, grid_thw=video_input["video_grid_thw"]
|
||||
)
|
||||
return video_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
get_embedding: bool = False,
|
||||
):
|
||||
"""Run forward pass for Qwen2_5-VL.
|
||||
|
||||
Args:
|
||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||
batch.
|
||||
positions: Flattened (concatenated) position ids corresponding to a
|
||||
batch.
|
||||
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
||||
opensource models), the shape will be `(3, seq_len)`,
|
||||
otherwise it will be `(seq_len,).
|
||||
(Use input_metadata.mrope_positions to replace it)
|
||||
"""
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
image_inputs = None
|
||||
if forward_batch.image_inputs is not None:
|
||||
image_inputs = [
|
||||
img for img in forward_batch.image_inputs if img is not None
|
||||
]
|
||||
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
or len(image_inputs) == 0
|
||||
):
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
else:
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||
"multimodal section rotary embedding requires "
|
||||
f"(3, seq_len) positions, but got {positions.size()}"
|
||||
)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
# [B, s, hidden_size]
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
|
||||
pixel_values = image.pixel_values.clone().detach().requires_grad_(False)
|
||||
image_grid_thws = torch.tensor(
|
||||
np.array(image.image_grid_thws), device="cuda"
|
||||
)
|
||||
image_offsets = image.image_offsets
|
||||
image_input = Qwen2VLImageInputs(
|
||||
pixel_values=pixel_values, image_grid_thw=image_grid_thws
|
||||
)
|
||||
image_embeds = self._process_image_input(image_input)
|
||||
|
||||
image_embeds_offset = 0
|
||||
for idx, image_offset in enumerate(image_offsets):
|
||||
if image_offset < prefix_len:
|
||||
continue
|
||||
num_image_tokens = self.calculate_num_image_tokens(
|
||||
image_grid_thws[idx]
|
||||
)
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len)
|
||||
right_idx = left_idx + num_image_tokens
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
hidden_size = image_embeds.shape[-1]
|
||||
|
||||
if hidden_size % tp_size != 0:
|
||||
padding_size = tp_size - (hidden_size % tp_size)
|
||||
image_embeds = F.pad(image_embeds, (0, padding_size))
|
||||
inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
|
||||
|
||||
hidden_chunk_size = image_embeds.shape[-1] // tp_size
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
start_dim = rank * hidden_chunk_size
|
||||
end_dim = (rank + 1) * hidden_chunk_size
|
||||
inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
|
||||
image_embeds[
|
||||
image_embeds_offset : image_embeds_offset
|
||||
+ num_image_tokens,
|
||||
...,
|
||||
start_dim:end_dim,
|
||||
]
|
||||
)
|
||||
image_embeds_offset += num_image_tokens
|
||||
|
||||
input_ids = None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
if "visual" in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if "visual" in name and "qkv.weight" in name:
|
||||
visual_num_heads = self.config.vision_config.num_heads
|
||||
visual_embed_dim = self.config.vision_config.hidden_size
|
||||
head_size = visual_embed_dim // visual_num_heads
|
||||
loaded_weight = loaded_weight.view(
|
||||
3, visual_num_heads, head_size, visual_embed_dim
|
||||
)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
||||
elif "visual" in name and "qkv.bias" in name:
|
||||
visual_num_heads = self.config.vision_config.num_heads
|
||||
visual_embed_dim = self.config.vision_config.hidden_size
|
||||
head_size = visual_embed_dim // visual_num_heads
|
||||
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
if "visual" in name:
|
||||
# adapt to VisionAttention
|
||||
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||
|
||||
try:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
except KeyError:
|
||||
print(params_dict.keys())
|
||||
raise
|
||||
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = [Qwen2_5_VLForConditionalGeneration]
|
||||
AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)
|
||||
@@ -31,8 +31,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import Qwen2VLConfig
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
|
||||
|
||||
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.layers.activation import QuickGELU
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
|
||||
Reference in New Issue
Block a user