Support InternVL3 (#5350)
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
python/sglang/srt/models/internvl.py (new file, 670 lines)
@@ -0,0 +1,670 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Iterable, List, Optional, Tuple, Union

import torch

# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
# Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
import torch.nn.functional as F
from einops import rearrange, repeat
from sgl_kernel.flash_attn import flash_attn_varlen_func
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.utils import logger


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
    ):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(
        self,
        qkv,
        causal=False,
        max_s=None,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value, packed along
                dim 2, of shape (B, S, 3, H, D); if unpadded: (nnz, 3, h, d).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        batch_size, seqlen, _, nheads, d = qkv.shape
        if batch_size == 0 or seqlen == 0:
            output_shape = (batch_size, seqlen, nheads, d)
            return (
                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
                None,
            )

        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
        q, k, v = qkv_reshaped.unbind(1)

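        # All sequences in this path have the same length, so the cumulative
        # sequence-length tensor expected by the varlen kernel is a simple
        # arithmetic progression, e.g. batch_size=2, seqlen=5 -> [0, 5, 10].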
        max_s = seqlen
        cu_seqlens = torch.arange(
            0,
            (batch_size + 1) * seqlen,
            step=seqlen,
            dtype=torch.int32,
            device=qkv.device,
        )
        output_reshaped = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            softmax_scale=self.softmax_scale,
            causal=causal,
        )
        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
        return output, None


class InternAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.inner_attn = FlashAttention(softmax_scale=self.scale)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _flash_attn(
        self,
        x,
    ):
        qkv = self.qkv(x)
        qkv = rearrange(
            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
        )

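        # QK-norm is applied per token over the flattened (heads * head_dim)
        # axis, matching the InternRMSNorm weight shape of embed_dim, and the
        # normalized tensors are then restored to (b, s, h, d).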
        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv,
        )
        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self._flash_attn(hidden_states)
        return x


class InternVisionEmbeddings(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim)
        )

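    # Position embeddings are stored for the configured image_size grid; when
    # the input resolution differs, the patch-grid portion is bicubically
    # resized to (H, W) below. Illustrative example (assumed values): with
    # image_size=448 and patch_size=14 the stored grid is 32x32 = 1024 patches
    # plus one class token.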
    def _get_pos_embed(self, pos_embed, H, W):
        target_dtype = pos_embed.dtype
        pos_embed = (
            pos_embed.float()
            .reshape(
                1,
                self.image_size // self.patch_size,
                self.image_size // self.patch_size,
                -1,
            )
            .permute(0, 3, 1, 2)
        )
        pos_embed = (
            F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
            .reshape(1, -1, H * W)
            .permute(0, 2, 1)
            .to(target_dtype)
        )
        return pos_embed

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values
        )  # shape = [*, channel, height, width]
        batch_size, _, height, width = patch_embeds.shape
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
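        # Keep the class-token position embedding as-is and interpolate only
        # the patch-grid portion to the actual (height, width) of this input.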
        position_embedding = torch.cat(
            [
                self.position_embedding[:, :1, :],
                self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
            ],
            dim=1,
        )
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


class InternRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

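    # RMSNorm: y = weight * x / sqrt(mean(x**2, dim=-1) + eps), computed in
    # float32 for numerical stability and cast back to the input dtype.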
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class InternMLP(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


NORM2FN = {
    "rms_norm": InternRMSNorm,
    "layer_norm": nn.LayerNorm,
}


class InternVisionEncoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        drop_path_rate: float,
        quant_config: QuantizationConfig = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type
        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )
        self.drop_path2 = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[
        torch.FloatTensor,
        Optional[torch.FloatTensor],
        Optional[Tuple[torch.FloatTensor]],
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        """
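        # Pre-norm residual blocks with learnable layer scale (ls1/ls2) and
        # optional stochastic depth: x = x + DropPath(ls * f(norm(x))).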
        hidden_states = hidden_states + self.drop_path1(
            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
        )

        hidden_states = hidden_states + self.drop_path2(
            self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
        )

        return hidden_states


class InternVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is an
    [`InternVisionEncoderLayer`].

    Args:
        config (`InternConfig`):
            The corresponding vision configuration for the `InternEncoder`.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # stochastic depth decay rule
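        # (illustrative: drop_path_rate=0.1 over 24 layers yields per-layer
        # drop probabilities linearly spaced from 0.0 to 0.1)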
        dpr = [
            x.item()
            for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(
            [
                InternVisionEncoderLayer(config, dpr[idx], quant_config)
                for idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states,
            )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )

class InternVisionModel(PreTrainedModel):
    main_input_name = "pixel_values"
    _supports_flash_attn_2 = True
    config_class = PretrainedConfig
    _no_split_modules = ["InternVisionEncoderLayer"]

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(
            config,
        )
        self.encoder = InternVisionEncoder(config, quant_config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        pos_emb = self.embeddings.position_embedding
        _, num_positions, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        pos_emb = (
            pos_emb[:, 1:, :]
            .reshape(1, old_size // patch_size, old_size // patch_size, -1)
            .permute(0, 3, 1, 2)
        )
        pos_emb = F.interpolate(
            pos_emb.float(),
            size=new_size // patch_size,
            mode="bicubic",
            align_corners=False,
        )
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
        logger.info(
            "Resized position embeddings from {} to {}".format(old_size, new_size)
        )

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        if pixel_values is not None:
            pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None and pixel_embeds is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class InternVLChatModel(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        use_flash_attn=True,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
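        # Number of LLM tokens per image tile after pixel shuffle; e.g. with
        # image_size=448, patch_size=14 and downsample_ratio=0.5 (illustrative
        # values) this is 32 * 32 * 0.25 = 256 tokens.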
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        config.vision_config.use_flash_attn = True if use_flash_attn else False
        config.llm_config._attn_implementation = (
            "flash_attention_2" if use_flash_attn else "eager"
        )

        logger.info(f"num_image_token: {self.num_image_token}")
        logger.info(f"ps_version: {self.ps_version}")

        self.vision_model = InternVisionModel(config.vision_config)
        if config.llm_config.architectures[0] == "Qwen2ForCausalLM":
            self.language_model = Qwen2ForCausalLM(
                config=config.llm_config, quant_config=quant_config
            )
        elif config.llm_config.architectures[0] == "InternLM2ForCausalLM":
            self.language_model = InternLM2ForCausalLM(
                config=config.llm_config, quant_config=quant_config
            )
        else:
            raise NotImplementedError(
                f"{config.llm_config.architectures[0]} is not implemented."
            )

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

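        # Vision-to-LLM projector: pixel shuffle multiplies the channel dim by
        # (1 / downsample_ratio) ** 2, so with vit_hidden_size=1024 and
        # downsample_ratio=0.5 (illustrative values) the projector input is 4096.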
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(
                vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
            ),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

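    # Pixel shuffle trades spatial resolution for channels; e.g. an input of
    # shape (n, 32, 32, 1024) with scale_factor=0.5 (illustrative) comes out
    # as (n, 16, 16, 4096).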
    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        if self.ps_version == "v1":
            logger.warning(
                "In ps_version 'v1', the height and width have not been swapped back, "
                "which results in a transposed image."
            )
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

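    # Vision tower -> ViT patch embeddings (CLS token dropped) -> reshape to a
    # square grid -> pixel shuffle -> mlp1 projection into LLM embedding space.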
    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=False, return_dict=True
            ).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
            ).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def get_image_feature(self, items: List[MultimodalDataItem]):
        """
        Projects the last hidden state from the vision model into language model space.

        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        pixel_values = torch.cat([item.pixel_values for item in items])
        image_features = self.extract_feature(pixel_values)
        return image_features

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:

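        # general_mm_embed_routine embeds text tokens with the language model's
        # embedding layer, fills the image-padded positions with features from
        # get_image_feature, and then runs the language model forward.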
        hs = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            image_data_embedding_func=self.get_image_feature,
            positions=positions,
        )

        return hs

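    # Image placeholder tokens between each image start/end id pair are padded
    # via MultiModalityDataPaddingPatternTokenPairs so they can be matched to
    # the corresponding image embeddings later.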
    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        # Get all special token IDs
        im_start_id: int = mm_inputs.im_start_id
        im_end_id: int = mm_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return helper.pad_input_tokens(input_ids, mm_inputs)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("gate_up_proj", "w1", 0),
                ("gate_up_proj", "w3", 1),
            ]
        elif "Qwen2ForCausalLM" in self.config.llm_config.architectures:
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
        params_dict = dict(self.named_parameters())

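        # Checkpoint weights whose names match a shard entry above are loaded
        # into the corresponding fused parameter (e.g. gate_proj/up_proj into
        # gate_up_proj); everything else is loaded directly.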
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
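                # InternLM2 stores attention weights as a fused "wqkv" tensor
                # in which each group of (kv_groups q-heads + 1 k-head +
                # 1 v-head) is interleaved; split it and feed q/k/v to the
                # shard loader. Illustrative example (assumed config): 32 query
                # heads with 8 KV heads gives kv_groups = 4, so the view below
                # has 2 + 4 = 6 sub-blocks of head_dim rows per group.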
if "wqkv" in name:
|
||||
config = self.config
|
||||
kv_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
loaded_weight = loaded_weight.view(
|
||||
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
|
||||
)
|
||||
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
|
||||
wq = wq.reshape(-1, wq.shape[-1])
|
||||
wk = wk.reshape(-1, wk.shape[-1])
|
||||
wv = wv.reshape(-1, wv.shape[-1])
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, wq, "q")
|
||||
weight_loader(param, wk, "k")
|
||||
weight_loader(param, wv, "v")
|
||||
else:
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = InternVLChatModel
|
||||