Files
sglang/python/sglang/srt/models/minicpmv.py
2025-01-19 21:33:27 +08:00

1239 lines
42 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
from functools import cached_property, partial
from typing import (
Any,
Callable,
Iterable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
import torch
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
# A raw image is either a PIL image or a tensor.
RawImageType = Union[Image.Image, torch.Tensor]
class Idefics2VisionMLP(nn.Module):
    """Feed-forward sub-block of a vision encoder layer.

    Applies ``fc1`` (``hidden_size`` -> ``intermediate_size``), the activation
    named by ``config.hidden_act``, and then ``fc2`` back to ``hidden_size``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            hidden_dim,
            inner_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            inner_dim,
            hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run ``fc1`` -> activation -> ``fc2`` and return the result."""
        # Both linear layers return a 2-tuple; only the first item is used.
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
class Idefics2EncoderLayer(nn.Module):
    """One vision encoder layer.

    The layer normalises its input, applies self-attention and adds the
    residual; it then normalises again, applies the MLP and adds the second
    residual.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        # Attention heads are divided evenly across tensor-parallel ranks.
        heads_per_rank = divide(
            self.num_heads, get_tensor_model_parallel_world_size()
        )
        self.self_attn = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=heads_per_rank,
            projection_size=config.intermediate_size,
            use_qkv_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Idefics2VisionMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections.

        Args:
            hidden_states: Input to the layer; the original docstring gives
                the shape as ``(batch, seq_len, embed_dim)``.
            cu_seqlens: Passed unchanged to the attention module.
            forward_batch: Accepted for interface compatibility; it is not
                forwarded to the attention module.
        """
        attn_out = self.self_attn(
            self.layer_norm1(hidden_states),
            cu_seqlens=cu_seqlens,
        )
        after_attn = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(after_attn))
        return after_attn + mlp_out
class Idefics2Encoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` [`Idefics2EncoderLayer`] modules.

    Args:
        config: Idefics2Config
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        stack = []
        for _ in range(config.num_hidden_layers):
            stack.append(Idefics2EncoderLayer(config, quant_config=quant_config))
        self.layers = nn.ModuleList(stack)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        r"""Feed ``inputs_embeds`` through every layer in order.

        Args:
            inputs_embeds (torch.Tensor):
                Embedded representation that is passed to the first layer.
            cu_seqlens (torch.Tensor):
                Handed unchanged to every layer.
            forward_batch (ForwardBatch):
                Handed unchanged to every layer.

        Returns:
            The output of the last layer.
        """
        state = inputs_embeds
        for layer in self.layers:
            state = layer(state, cu_seqlens=cu_seqlens, forward_batch=forward_batch)
        return state
class Idefics2VisionEmbeddings(nn.Module):
    """Patch and position embeddings for images of variable resolution.

    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings`.
    The modifications are adapted from [Patch n' Pack: NaViT, a Vision
    Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304),
    which allows treating images in their native aspect ratio without
    resizing them to one fixed size. The starting point is the original
    pre-trained SigLIP model (which uses fixed-size square images), adapted
    by training on images of variable resolutions.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel size and stride both equal patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        patch_attention_mask: torch.BoolTensor,
        tgt_sizes: Optional[torch.IntTensor] = None,
    ) -> torch.Tensor:
        """Embed patches and add bucketed 2D position embeddings.

        Args:
            pixel_values: 4D tensor; its last two dimensions are read as the
                maximum image height and width.
            patch_attention_mask: One boolean mask per batch item marking the
                patch slots that receive a position id.
            tgt_sizes: Optional per-item ``(patches_h, patches_w)``. When
                omitted, the counts are derived from the mask's first column
                and first row.

        Returns:
            Patch embeddings plus position embeddings.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        conv_weight = self.patch_embedding.weight
        pixel_values = pixel_values.to(
            device=conv_weight.device, dtype=conv_weight.dtype
        )
        embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        grid_h = max_im_h // self.patch_size
        grid_w = max_im_w // self.patch_size

        # Bucket edges that split [0, 1) into num_patches_per_side bins.
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
        )
        position_ids = torch.full(size=(batch_size, grid_h * grid_w), fill_value=0)

        for batch_idx, mask in enumerate(patch_attention_mask):
            if tgt_sizes is None:
                rows_in_image = mask[:, 0].sum()
                cols_in_image = mask[0].sum()
            else:
                rows_in_image = tgt_sizes[batch_idx][0]
                cols_in_image = tgt_sizes[batch_idx][1]

            # Map each patch's fractional coordinate onto the trained grid.
            row_buckets = torch.bucketize(
                torch.arange(0, 1 - 1e-6, 1 / rows_in_image), boundaries, right=True
            )
            col_buckets = torch.bucketize(
                torch.arange(0, 1 - 1e-6, 1 / cols_in_image), boundaries, right=True
            )
            flat_ids = (
                row_buckets[:, None] * self.num_patches_per_side + col_buckets
            ).flatten()

            # position_ids is built on the default device, hence the .cpu() mask.
            position_ids[batch_idx][mask.view(-1).cpu()] = flat_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        return embeddings + self.position_embedding(position_ids)
class Idefics2VisionTransformer(nn.Module):
    """Vision tower: patch embeddings, encoder stack, final LayerNorm."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(config)
        self.encoder = Idefics2Encoder(config=config, quant_config=quant_config)
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def get_input_embeddings(self):
        """Return the embeddings sub-module."""
        return self.embeddings

    def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
        """Return cumulative patch counts, prefixed with 0, as int32.

        Args:
            tgt_sizes: 2D tensor whose columns 0 and 1 are multiplied to get
                the number of patches for each row.
        """
        seq_lens = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # shape: (batch_size,)
        # Take a prefix sum to obtain cu_seqlens; a leading 0 is inserted to
        # act as the offset of the first sequence.
        leading_zero = torch.tensor([0], device=seq_lens.device, dtype=torch.int32)
        running_total = torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
        return torch.cat([leading_zero, running_total], dim=0).to(tgt_sizes.device)

    def forward(
        self,
        pixel_values,
        forward_batch: ForwardBatch,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        tgt_sizes: Optional[torch.IntTensor] = None,
    ) -> torch.Tensor:
        """Embed, encode and layer-normalise ``pixel_values``.

        ``forward_batch`` is forwarded to the encoder only; the embeddings
        module does not receive it.
        """
        embedded = self.embeddings(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            tgt_sizes=tgt_sizes,
        )
        encoded = self.encoder(
            embedded,
            cu_seqlens=self.compute_cu_seqlens(tgt_sizes),
            forward_batch=forward_batch,
        )
        return self.post_layernorm(encoded)
class MiniCPMVImagePixelInputs(TypedDict):
    # Discriminator tag for this input variant.
    type: Literal["pixel_values"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`.
    # Note that the image size may vary, so we pass it as a list
    # instead of a batched tensor.
    data: List[torch.Tensor]

    # Shape: `(batch_size * num_images, 2)`.
    # This should be in `(start, stop)` format.
    image_bounds: torch.Tensor

    # Shape: `(batch_size * num_images, 2)`.
    # This should be in `(height, width)` format.
    tgt_sizes: torch.Tensor
class MiniCPMVImageEmbeddingInputs(TypedDict):
    # Discriminator tag for this input variant.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` must match the hidden size of the language model backbone.
    data: torch.Tensor

    # Shape: `(batch_size * num_images, 2)`.
    # This should be in `(start, stop)` format.
    image_bounds: torch.Tensor
# Image inputs are either raw pixel values or precomputed image embeddings.
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs]
# Default norm-layer factory used by the resamplers: LayerNorm with eps=1e-6.
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
class BaseResampler(nn.Module):
    """Shared building blocks of the 2D perceiver-resampler.

    The resampler uses a single cross-attention layer driven by
    ``num_queries`` learnable queries. The original documentation describes
    the query count as ``grid_size**2`` and the positional embedding as a 2D
    sincos embedding.

    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        do_post_projection: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Learnable queries, initialised from a truncated normal distribution.
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)

        needs_projection = kv_dim is not None and kv_dim != embed_dim
        if needs_projection:
            self.kv_proj = ReplicatedLinear(
                kv_dim,
                embed_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_proj",
            )
        else:
            # Mirror the (output, extra) tuple that ReplicatedLinear.forward
            # returns so callers can unpack both variants the same way.
            def _passthrough(*args, **kwargs):
                return nn.Identity()(*args, **kwargs), None

            self.kv_proj = _passthrough  # type: ignore # noqa

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.do_post_projection = do_post_projection
        if do_post_projection:
            self.ln_post = norm_layer(embed_dim)
            self.proj = nn.Parameter(
                (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)
            )
        else:
            self.ln_post = None
            self.proj = None

    def _init_weights(self, m: nn.Module) -> None:
        """Initialise one sub-module; intended for use with ``nn.Module.apply``.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm gets zero bias and unit weight. Other modules are untouched.
        """
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def _repeat(self, query, N: int):
        """Insert a new dimension 1 into ``query`` and repeat it ``N`` times."""
        expanded = query.unsqueeze(1)
        return expanded.repeat(1, N, 1)
class Resampler2_5(BaseResampler):
    """Resampler with a cached 2D sincos position table that grows on demand."""

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        max_size: Tuple[int, int] = (70, 70),
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            num_queries,
            embed_dim,
            num_heads,
            kv_dim,
            norm_layer,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)
        self.apply(self._init_weights)

    def _set_2d_pos_cache(
        self, max_size: Tuple[int, int], device: torch.types.Device = "cpu"
    ) -> None:
        """(Re)build the position table for ``max_size`` and store it as a buffer."""
        table = get_2d_sincos_pos_embed(self.embed_dim, max_size, version=(2, 5))
        # Non-persistent, so the table is not written into the state dict.
        self.register_buffer(
            "pos_embed",
            torch.from_numpy(table).float().to(device),
            persistent=False,
        )

    def _adjust_pos_cache(
        self, tgt_sizes: torch.Tensor, device: torch.types.Device
    ) -> None:
        """Grow the cached position table if any target size exceeds it."""
        needed_h = tgt_sizes[:, 0].max().item()
        needed_w = tgt_sizes[:, 1].max().item()
        assert isinstance(needed_h, int) and isinstance(needed_w, int)

        cached_h, cached_w = self.max_size[0], self.max_size[1]
        if needed_h <= cached_h and needed_w <= cached_w:
            return
        self.max_size = (max(needed_h, cached_h), max(needed_w, cached_w))
        self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
        """Cross-attend the learnable queries over the visual features ``x``.

        Args:
            x: Visual features; the first dimension must match ``tgt_sizes``.
            tgt_sizes: Per-item ``(height, width)`` patch counts.

        Returns:
            The attended queries after ``ln_post`` and the ``proj`` matrix.
        """
        assert x.shape[0] == tgt_sizes.shape[0]
        batch = x.shape[0]
        device, dtype = x.device, x.dtype

        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
        self._adjust_pos_cache(tgt_sizes, device=device)

        longest = patch_len.max().item()
        assert isinstance(longest, int)
        key_padding_mask = torch.zeros(
            (batch, longest), dtype=torch.bool, device=device
        )

        per_item_pos = []
        for idx, (tgt_h, tgt_w) in enumerate(tgt_sizes.tolist()):
            window = self.pos_embed[:tgt_h, :tgt_w, :]
            per_item_pos.append(
                window.reshape((tgt_h * tgt_w, -1)).to(dtype)
            )  # patches * D
            # Everything past this item's real patches is padding.
            key_padding_mask[idx, patch_len[idx] :] = True

        padded_pos = torch.nn.utils.rnn.pad_sequence(
            per_item_pos, batch_first=True, padding_value=0.0
        )
        padded_pos = padded_pos.permute(1, 0, 2)  # BLD => L * B * D

        x, _ = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        queries = self.ln_q(self.query)  # Q * D
        attended, _unused = self.attn(
            self._repeat(queries, batch),  # Q * B * D
            x + padded_pos,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[:2]
        # attended: Q * B * D
        result = attended.permute(1, 0, 2)  # B * Q * D
        result = self.ln_post(result)
        return result @ self.proj
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    """Return the MiniCPM-V version of ``config`` as a tuple of ints.

    When ``config.version`` is present it is stringified and split on ``.``.
    Otherwise the version is guessed: ``(2, 0)`` when ``hidden_size`` is 2304
    and ``query_num`` is 64, else ``(2, 5)``.
    """
    declared = getattr(config, "version", None)
    if declared is not None:
        return tuple(int(part) for part in str(declared).split("."))

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    looks_like_v2_0 = config.hidden_size == 2304 and config.query_num == 64
    return (2, 0) if looks_like_v2_0 else (2, 5)
class MiniCPMVBaseModel(nn.Module):
    """Shared skeleton of the MiniCPM-V models.

    The class is meant to be subclassed, not instantiated: the ``init_*`` and
    ``get_vision_*`` hooks raise ``NotImplementedError`` here and must be
    provided by a version-specific subclass.
    """

    def __init__(
        self,
        *,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings` but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
        # and config class
        self.config = config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(config=config, quant_config=quant_config)
        self.vpm = self.init_vision_module(config, quant_config)
        self.vision_dim = (
            self.vpm.embed_dim
            if self.version == (2, 0)
            else self.vpm.embeddings.embed_dim
        )
        self.embed_dim = self.config.hidden_size
        self.resampler = self.init_resampler(
            self.embed_dim, self.vision_dim, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)

    @cached_property
    def sampler(self):
        """Return the LLM's own sampler if it has one, else a default sampler."""
        if hasattr(self.llm, "sampler"):
            return self.llm.sampler
        return get_sampler()

    def _get_image_bounds(
        self,
        input_ids: torch.Tensor,
        pad_values: List[int],
        im_start_id: torch.Tensor,
        im_end_id: torch.Tensor,
        slice_start_id: Optional[torch.Tensor] = None,
        slice_end_id: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Locate the token spans that image embeddings must overwrite.

        Args:
            input_ids: 1D tensor of token ids.
            pad_values: Ids used to pad image regions; used to detect a region
                whose start token is missing from ``input_ids``.
            im_start_id: Indexable whose first item is the image-start id.
            im_end_id: Indexable whose first item is the image-end id.
            slice_start_id: Optional slice-start id, same layout.
            slice_end_id: Optional slice-end id, same layout. Only consulted
                when ``slice_start_id`` is given.

        Returns:
            An int64 tensor of shape ``(num_regions, 2)`` holding
            ``(start, end)`` index pairs, where ``start`` is the position
            right after a start token and ``end`` is the position of the
            matching end token. The tensor is empty (``(0, 2)``) when no valid
            pair exists.
        """
        # All the images in the batch should share the same special image
        # bound token ids.
        start_cond = input_ids == im_start_id[0]
        end_cond = input_ids == im_end_id[0]
        if slice_start_id is not None:
            start_cond |= input_ids == slice_start_id[0]
            if slice_end_id is not None:
                end_cond |= input_ids == slice_end_id[0]

        (image_start_tokens,) = torch.where(start_cond)
        image_start_tokens += 1
        (image_end_tokens,) = torch.where(end_cond)

        # The im_start_id can sometimes be cached as prefix, but it is needed
        # for the embedding of the images. In that case the sequence begins
        # with padding and has one more end token than start tokens, so a
        # region starting at index 0 is restored.
        if (
            len(image_start_tokens) + 1 == len(image_end_tokens)
            and pad_values is not None
            and input_ids[0] in pad_values
            and (
                # No start token at all: the only region is the cached one.
                len(image_start_tokens) == 0
                or image_end_tokens[0] < image_start_tokens[0]
            )
        ):
            image_start_tokens = torch.cat(
                [
                    torch.tensor([0], device=image_start_tokens.device),
                    image_start_tokens,
                ]
            )

        # Pair starts and ends positionally and keep only start < end.
        num_pairs = min(len(image_start_tokens), len(image_end_tokens))
        starts = image_start_tokens[:num_pairs]
        ends = image_end_tokens[:num_pairs]
        keep = starts < ends
        return torch.stack([starts[keep], ends[keep]], dim=1)

    def get_embedding(
        self,
        input_ids: torch.Tensor,
        image_inputs: Optional[MiniCPMVImageInputs],
        forward_batch: ForwardBatch,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``input_ids`` and splice the vision features into image spans.

        Returns:
            ``(vlm_embedding, vision_hidden_states)``. The text embedding is
            modified in place at the positions listed in ``image_bounds``.
            ``vision_hidden_states`` is an empty tensor when there is no image.
        """
        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
        if image_inputs is None:  # No image
            vision_hidden_states = torch.tensor([], device=input_ids.device)
        else:
            if image_inputs["type"] == "image_embeds":
                vision_hidden_states = (
                    image_inputs["data"]
                    .type(vlm_embedding.dtype)
                    .to(vlm_embedding.device)
                )
            else:
                vision_hidden_states = self.get_vision_hidden_states(
                    forward_batch, image_inputs
                )

            # See NOTE in _parse_and_validate_inputs
            image_bounds = image_inputs["image_bounds"]
            if len(image_bounds) > 0:
                # torch.stack requires every span to have the same length.
                image_indices = torch.stack(
                    [
                        torch.arange(start, end, dtype=torch.long)
                        for start, end in image_bounds.tolist()
                    ]
                ).to(vlm_embedding.device)
                vlm_embedding.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
                )
        return vlm_embedding, vision_hidden_states

    def _parse_and_validate_inputs(
        self,
        input_ids: torch.Tensor,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        """Validate the multimodal kwargs and package them as image inputs.

        Returns:
            ``MiniCPMVImageEmbeddingInputs`` when ``image_embeds`` is given,
            ``MiniCPMVImagePixelInputs`` when pixel values are given, or
            ``None`` when no pixel data is present.

        Raises:
            ValueError: On a wrong container type or mismatched lengths.
        """
        pixel_values = kwargs.pop("pixel_values", [])
        tgt_sizes = kwargs.pop("tgt_sizes", [])
        im_start_id = kwargs.pop("im_start_id", None)
        im_end_id = kwargs.pop("im_end_id", None)
        slice_start_id = kwargs.pop("slice_start_id", None)
        slice_end_id = kwargs.pop("slice_end_id", None)
        image_embeds = kwargs.pop("image_embeds", None)
        pad_values = kwargs.pop("pad_values", None)

        if image_embeds is not None:
            image_bounds = self._get_image_bounds(
                input_ids=input_ids,
                pad_values=pad_values,
                im_start_id=im_start_id,
                im_end_id=im_end_id,
                slice_start_id=slice_start_id,
                slice_end_id=slice_end_id,
            )
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of image embeds. "
                    f"Got type: {type(image_embeds)}"
                )
            if isinstance(image_embeds, list):
                image_embeds = torch.concat(image_embeds)
            return MiniCPMVImageEmbeddingInputs(
                image_bounds=image_bounds,
                data=image_embeds,
                type="image_embeds",
            )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
            )
        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
            )
        if len(pixel_values) != len(tgt_sizes):
            raise ValueError(
                "Inconsistent batch lengths, found: "
                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
            )

        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
            if len(pixel_b) != len(tgt_b):
                raise ValueError(
                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
                )
            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
                # Extending iterates each item along its first dimension.
                pixel_values_flat += pixel_n
                tgt_sizes_flat += tgt_n

        # NOTE: Input IDs does not contain image tokens during memory profiling,
        # so we allow it to be empty
        if len(pixel_values_flat) != len(tgt_sizes_flat):
            raise ValueError(
                "Inconsistent flattened lengths, found: "
                f"{len(pixel_values_flat)} vs. "
                f"{len(tgt_sizes_flat)}"
            )
        if len(pixel_values_flat) == 0:
            return None

        image_bounds = self._get_image_bounds(
            input_ids=input_ids,
            pad_values=pad_values,
            im_start_id=im_start_id,
            im_end_id=im_end_id,
            slice_start_id=slice_start_id,
            slice_end_id=slice_end_id,
        )
        return MiniCPMVImagePixelInputs(
            image_bounds=image_bounds.to(device=input_ids.device),
            data=pixel_values_flat,
            tgt_sizes=torch.stack(tgt_sizes_flat),
            type="pixel_values",
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the multimodal forward pass and return the processed logits.

        Image data is taken from the non-``None`` entries of
        ``forward_batch.image_inputs``; the special token ids and pad values
        are read from the first such entry.
        """
        # Entries may be None for requests without images, so the special
        # token ids must come from the first entry that actually exists.
        present_inputs = [
            item for item in (forward_batch.image_inputs or []) if item is not None
        ]
        if present_inputs:
            reference = present_inputs[0]
            kwargs.update(
                {
                    "pixel_values": [item.pixel_values for item in present_inputs],
                    "tgt_sizes": [item.tgt_sizes for item in present_inputs],
                    "im_start_id": reference.im_start_id,
                    "im_end_id": reference.im_end_id,
                    "slice_start_id": reference.slice_start_id,
                    "slice_end_id": reference.slice_end_id,
                    "pad_values": reference.pad_values,
                }
            )

        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)

        # Clamp input ids. This is because the input_ids for the image tokens are
        # filled with the hash values of the image for the prefix matching in the
        # radix attention. These values are useless because their embeddings will
        # be replaced by vision embeddings anyway.
        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)

        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        # for `torch.compile` integration
        input_ids = None

        hidden_states = self.llm.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=vlm_embeddings,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.llm.lm_head, forward_batch
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped LLM."""
        return self.llm.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="llm", connector="resampler", tower_model="vpm"
        )

    def init_llm(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        """Build the language model; must be overridden by subclasses."""
        raise NotImplementedError

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
    ) -> nn.Module:
        """Build the vision tower; must be overridden by subclasses."""
        raise NotImplementedError

    def init_resampler(
        self,
        embed_dim: int,
        vision_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        """Build the resampler; must be overridden by subclasses."""
        raise NotImplementedError

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the vision tower; must be overridden by subclasses."""
        raise NotImplementedError

    def get_vision_hidden_states(
        self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
    ) -> torch.Tensor:
        """Turn image inputs into hidden states; must be overridden by subclasses."""
        raise NotImplementedError
class MiniCPMV2_6(MiniCPMVBaseModel):
    """MiniCPM-V 2.6: Qwen2 language model, Idefics2 vision tower, 2.5 resampler."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config=config, quant_config=quant_config)
        assert self.version == (2, 6)

    def init_llm(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        """Build the Qwen2 causal language model."""
        return Qwen2ForCausalLM(config=config, quant_config=quant_config)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
    ) -> nn.Module:
        """Build the Idefics2 vision tower from ``config.vision_config``.

        The last encoder layer is dropped when
        ``config.drop_vision_last_layer`` is truthy. ``embed_dim`` and
        ``patch_size`` are copied from the embeddings onto the returned model.
        """
        model = Idefics2VisionTransformer(
            config=config.vision_config, quant_config=quant_config
        )
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        setattr(model, "embed_dim", model.embeddings.embed_dim)
        setattr(model, "patch_size", model.embeddings.patch_size)
        return model

    def init_resampler(
        self,
        embed_dim: int,
        vision_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        """Build the resampler and move it to CUDA in the default dtype."""
        with set_default_torch_dtype(torch.float16):
            # The resampler in 2.6 remains consistent with the one in 2.5.
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
                quant_config=quant_config,
            )

        return resampler.to(device="cuda", dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
        forward_batch: Optional[ForwardBatch] = None,
    ) -> torch.Tensor:
        """Run the vision tower on ``pixel_values``.

        ``forward_batch`` is optional and passed through to the vision tower,
        whose ``forward`` requires that argument; omitting it previously made
        this method fail with a ``TypeError``.
        """
        vision_embedding = self.vpm(
            pixel_values,
            forward_batch=forward_batch,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return vision_embedding

    def get_vision_hidden_states(
        self,
        forward_batch: ForwardBatch,
        data: MiniCPMVImageInputs,
    ) -> torch.Tensor:
        """Pad the per-image pixel tensors into a batch, encode and resample them.

        Args:
            forward_batch: Passed through to the vision tower.
            data: Pixel inputs; ``data["data"]`` and ``data["tgt_sizes"]``
                are read.

        Returns:
            The resampler output for the encoded images.
        """
        pixel_values = data["data"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        # Flatten the first two dims and put the remaining one first so that
        # pad_sequence pads along it.
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0
        )
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

        # Mark the real (non-padded) patches of each image.
        patch_attn_mask = torch.zeros(
            (B, 1, max_patches), dtype=torch.bool, device=device
        )
        for i in range(B):
            patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        vision_embedding = self.vpm(
            all_pixel_values.type(dtype),
            forward_batch=forward_batch,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return self.resampler(vision_embedding, tgt_sizes)

    @staticmethod
    def _first_token_id(ids):
        """Return ``ids[0]`` as a plain Python value, or None if ``ids`` is empty/None."""
        if ids is None or len(ids) == 0:
            return None
        first = ids[0]
        return first.item() if isinstance(first, torch.Tensor) else first

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        """Replace the tokens inside every image/slice region with pad values.

        The start and end marker tokens are kept. ``image_inputs.image_offsets``
        is reset and filled with the index of each image-start token.

        Returns:
            A new list of the same length as ``input_ids``. ``input_ids`` is
            returned unchanged when the image start/end ids are not lists or
            when the number of start and end markers differ.
        """
        if not isinstance(image_inputs.im_start_id, list) or not isinstance(
            image_inputs.im_end_id, list
        ):
            return input_ids

        new_input_ids = []
        last_idx = 0
        image_idx = -1
        image_inputs.image_offsets = []

        # Get all special token IDs
        im_start_id = (
            image_inputs.im_start_id[0].item()
            if isinstance(image_inputs.im_start_id[0], torch.Tensor)
            else image_inputs.im_start_id[0]
        )
        im_end_id = (
            image_inputs.im_end_id[0].item()
            if isinstance(image_inputs.im_end_id[0], torch.Tensor)
            else image_inputs.im_end_id[0]
        )
        # Slice ids may be absent; None never equals a token id, so such
        # inputs simply contain no slice regions.
        slice_start_id = self._first_token_id(image_inputs.slice_start_id)
        slice_end_id = self._first_token_id(image_inputs.slice_end_id)

        # Find all start and end positions for both types
        start_indices = [
            i
            for i, x in enumerate(input_ids)
            if x == im_start_id or x == slice_start_id
        ]
        end_indices = [
            i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
        ]

        if len(start_indices) != len(end_indices):
            return input_ids

        # Process each region (both image and slice)
        for start_idx, end_idx in zip(start_indices, end_indices):
            # Add non-image tokens before this region
            new_input_ids.extend(
                input_ids[last_idx : start_idx + 1]
            )  # include start token

            is_image_start = input_ids[start_idx] == im_start_id

            if is_image_start:
                image_inputs.image_offsets += [start_idx]
                image_idx += 1

            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens

            # Generate pad_ids
            # NOTE(review): a slice region seen before any image start uses
            # image_idx == -1, i.e. the last pad value — confirm this is intended.
            pad_values = [image_inputs.pad_values[image_idx]]
            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
            pad_ids = pad_ids[:num_tokens]

            # Add pad_ids
            new_input_ids.extend(pad_ids)

            # Update last_idx to after end token
            last_idx = end_idx

        # Add remaining tokens after last region
        new_input_ids.extend(input_ids[last_idx:])
        assert len(input_ids) == len(new_input_ids)
        return new_input_ids
# Maps a parsed version tuple to the model class that implements it.
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
class MiniCPMV:
    """Version-dispatching wrapper around the concrete MiniCPM-V model.

    Different versions of MiniCPMV use different visual encoders and LLMs,
    which is not conducive to the current integration logic of LoRA and
    bitsandbytes in vLLM. Therefore, it is necessary to separate them.

    Attribute access and calls that this wrapper does not define are forwarded
    to the wrapped ``minicpmv`` instance.
    """

    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    minicpmv: nn.Module

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Instantiate the model class registered for ``config.version``.

        Raises:
            ValueError: If the version has no entry in ``_SUPPORT_VERSION``.
        """
        super().__init__()

        if not hasattr(config, "version"):
            version = (2, 6)
        else:
            version = str(config.version).split(".")
            version = tuple([int(x) for x in version])
        # Dispatch class based on version
        instance_class = _SUPPORT_VERSION.get(version)
        if instance_class is None:
            raise ValueError("Currently, MiniCPMV only supports versions 2.6")

        try:
            minicpmv = instance_class(config=config, quant_config=quant_config)
            self.minicpmv = minicpmv
        except Exception as e:
            print(f"Failed to instantiate MiniCPMV: {e}")
            # Bare raise re-raises the active exception with its traceback.
            raise
        self.config = config

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped model.

        ``minicpmv`` itself resolves to ``None`` when it has not been set yet,
        which prevents infinite recursion.
        """
        if name == "minicpmv":
            return None
        return getattr(self.minicpmv, name)

    def __call__(self, *args, **kwargs):
        """Call the wrapped model."""
        return self.minicpmv(*args, **kwargs)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the wrapped model's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Rotary-embedding caches, ``projector`` tensors and unknown
        ``model.vision_tower`` tensors are skipped. The q/k/v and gate/up
        projections are loaded into their fused parameters with a shard id.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.minicpmv.named_parameters())
        for name, loaded_weight in weights:
            # These tensors have no matching parameter and must be skipped.
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            # adapt to VisionAttention
            name = name.replace(r"self_attn.out_proj", r"self_attn.proj")

            if "sampler" in name:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # replace the name and load with customized loader
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. This only moves on
                # to the next mapping entry; the renamed bias is dropped by
                # the same check in the else branch below.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
# NOTE(review): presumably the name the model loader looks up to find this
# file's model class; confirm against the registry code.
EntryClass = MiniCPMV