# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import re
from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig,
                                                     QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """


class InternVLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """
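
# Illustrative shapes (values assumed, not from the original source): one
# prompt containing one 1344x896 image tiled into 6 crops plus a thumbnail at
# image_size=448 arrives as a pixel tensor of shape (7, 3, 448, 448) once the
# batch/image dimensions are flattened, per InternVLImagePixelInputs above.
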
""" InternVLImageInputs = Union[InternVLImagePixelInputs, InternVLImageEmbeddingInputs] # copied from https://huggingface.co/OpenGVLab/InternVL2-1B def build_transform(input_size): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD transform = T.Compose([ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) return transform # copied from https://huggingface.co/OpenGVLab/InternVL2-1B def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio return best_ratio def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, max_num: int, image_size: int, use_thumbnail: bool) -> Tuple[int, int, int]: aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # add thumbnail image if num_blocks > 1 if use_thumbnail and blocks > 1: blocks += 1 return blocks, target_width, target_height def calculate_num_blocks_wrapper(hf_config: PretrainedConfig, max_dynamic_patch: Optional[int] = None): if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch min_num = hf_config.min_dynamic_patch image_size = hf_config.vision_config.image_size use_thumbnail = hf_config.use_thumbnail return partial(calculate_num_blocks, min_num=min_num, max_num=max_dynamic_patch, image_size=image_size, use_thumbnail=use_thumbnail) # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, image_size: int, use_thumbnail: bool) -> List[Image.Image]: orig_width, orig_height = image.size # calculate the number of blocks without thumbnail blocks, target_width, target_height = calculate_num_blocks( orig_width, orig_height, min_num, max_num, image_size, use_thumbnail=False) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def image_to_pixel_values(image: 
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None):
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    return partial(image_to_pixel_values,
                   input_size=image_size,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   use_thumbnail=use_thumbnail)


def get_internvl_num_patches(hf_config: PretrainedConfig):
    vision_config = hf_config.vision_config
    downsample_ratio = hf_config.downsample_ratio
    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))


def get_max_internvl_image_tokens(ctx: InputContext,
                                  *,
                                  max_dynamic_patch: Optional[int] = None):
    hf_config = ctx.get_hf_config()

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    num_patches = get_internvl_num_patches(hf_config)
    return num_patches * max_dynamic_patch


def get_max_internvl_image_size(ctx: InputContext,
                                *,
                                max_dynamic_patch: Optional[int] = None):
    hf_config = ctx.get_hf_config()
    image_size = hf_config.vision_config.image_size

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1
    width = image_size * max_dynamic_patch
    height = image_size
    return width, height
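
# Worked arithmetic (assuming the common InternVL2 config: image_size=448,
# patch_size=14, downsample_ratio=0.5, max_dynamic_patch=12, use_thumbnail):
# get_internvl_num_patches -> (448 // 14)**2 * 0.5**2 = 1024 * 0.25 = 256
# tokens per tile, and get_max_internvl_image_tokens -> 256 * (12 + 1) = 3328
# tokens for a maximally tiled image plus its thumbnail.
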
class InternVLInputPipeline:

    def __init__(
        self,
        img_start_token: str,
        img_end_token: str,
        img_context_token: str,
    ) -> None:
        super().__init__()

        self.img_start_token = img_start_token
        self.img_end_token = img_end_token
        self.img_context_token = img_context_token

    def _create_image_prompt(self, feature_size: int,
                             num_patches: int) -> str:
        return (self.img_start_token + self.img_context_token * feature_size +
                self.img_end_token)

    def _expand_image_prompt(
        self,
        prompt: str,
        feature_sizes: List[int],
        num_patches: int,
    ) -> str:
        image_idx = sorted(
            map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))

        new_prompt = prompt
        for idx, feature_size in enumerate(feature_sizes, start=1):
            image_prompt = self._create_image_prompt(feature_size,
                                                     num_patches)
            if not image_idx:
                image_prompt = f"Image-{idx}: {image_prompt}"

            new_prompt = new_prompt.replace('<image>', image_prompt, 1)

        return new_prompt

    def input_processor(
        self,
        ctx: InputContext,
        inputs: DecoderOnlyInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> DecoderOnlyInputs:
        multi_modal_data = inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)
        num_blocks_calculator = calculate_num_blocks_wrapper(
            hf_config, max_dynamic_patch)
        if isinstance(image_data, Image.Image):
            width, height = image_data.size
            num_blocks, _, _ = num_blocks_calculator(width, height)
            image_feature_sizes = [num_blocks * num_patches]
        elif is_list_of(image_data, Image.Image):
            image_feature_sizes = []
            for image in image_data:
                width, height = image.size
                num_blocks, _, _ = num_blocks_calculator(width, height)
                image_feature_sizes.append(num_blocks * num_patches)
        elif isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, hidden_size = image_data.shape
            image_feature_sizes = [image_feature_size]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        prompt = inputs.get("prompt")
        prompt_token_ids = inputs["prompt_token_ids"]
        if prompt is None:
            prompt = tokenizer.decode(prompt_token_ids)

        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

        return token_inputs(prompt=prompt,
                            prompt_token_ids=new_prompt_token_ids,
                            multi_modal_data=multi_modal_data)

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch)
        if isinstance(data, Image.Image):
            data = image_pixel_values_mapper(data)
            # Add an N dimension for number of images per prompt (currently 1).
            data = data.unsqueeze(0)
        elif is_list_of(data, Image.Image):
            # we can't stack here because images may have different num_patches
            data = [image_pixel_values_mapper(img) for img in data]
        else:
            return MultiModalKwargs({"image_embeds": data})
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)
        image_token_id = tokenizer.encode(self.img_context_token,
                                          add_special_tokens=False,
                                          return_tensors="pt")[0]

        return MultiModalKwargs({
            "pixel_values": data,
            "image_token_id": image_token_id
        })

    def dummy_data(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        num_images = mm_counts["image"]

        hf_config = ctx.get_hf_config()

        image_feature_size = get_max_internvl_image_tokens(
            ctx, max_dynamic_patch=max_dynamic_patch)
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        seq_data, ranges = dummy_seq_data_for_clip(
            hf_config.vision_config,
            seq_len,
            num_images,
            image_token_id=tokenizer.encode(self.img_context_token,
                                            add_special_tokens=False)[0],
            image_feature_size_override=image_feature_size,
        )

        max_image_width, max_image_height = get_max_internvl_image_size(
            ctx, max_dynamic_patch=max_dynamic_patch)

        mm_data = dummy_image_for_clip(
            hf_config.vision_config,
            num_images,
            image_width_override=max_image_width,
            image_height_override=max_image_height,
        )

        return DummyData(seq_data, mm_data, ranges)


input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
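
# Illustrative expansion (token counts assumed): given feature_size image
# tokens, _expand_image_prompt rewrites a prompt such as
#     "Image-1: <image>\nDescribe the image."
# into
#     "Image-1: <img><IMG_CONTEXT>...<IMG_CONTEXT></img>\nDescribe the image."
# where <IMG_CONTEXT> repeats num_blocks * num_patches times, so the language
# model sees one placeholder token per merged vision embedding.
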
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            config.text_config,
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"))

        self.mlp1 = self._init_mlp1(config)

        self.img_context_token_id = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        # The AWQ models from OpenGVLab are missing `modules_to_not_convert`;
        # patch the quant_config to add it back.
        if isinstance(quant_config, AWQConfig):
            text_config = config.text_config
            llm_quant_config = getattr(text_config, "quantization_config",
                                       None)
            if (not quant_config.modules_to_not_convert) and \
                    (llm_quant_config is not None):
                quant_config.modules_to_not_convert.append("vision_model")

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        if not is_mono:
            vision_feature_layer = config.select_layer
            if vision_feature_layer < 0:
                num_hidden_layers = config.vision_config.num_hidden_layers \
                    + vision_feature_layer + 1
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            return InternVisionPatchModel(config.vision_config)

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        return nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            pass
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
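
    # Shape walkthrough for extract_feature (assuming image_size=448,
    # patch_size=14, downsample_ratio=0.5): the ViT returns (B, 1 + 1024, C);
    # dropping the CLS token and reshaping gives (B, 32, 32, C); pixel_shuffle
    # trades spatial resolution for channels, yielding (B, 16, 16, 4C);
    # flattening gives (B, 256, 4C), which mlp1 projects to the LLM hidden
    # size — 256 embeddings per 448x448 tile, matching num_image_token.
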
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        self.img_context_token_id = image_token_id[0]

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # We need to flatten (B, N, P) to (B*N*P),
            # so we call flatten_bn twice.
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(flatten_bn(pixel_values), concat=True)),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["data"])

        return image_embeds

    def _get_visual_token_mask(self,
                               input_ids: torch.Tensor) -> torch.Tensor:
        if self.is_mono:
            visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            visual_token_mask = None
        return visual_token_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
            visual_token_mask = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.img_context_token_id)
                visual_token_mask = self._get_visual_token_mask(input_ids)
                input_ids = None
            else:
                inputs_embeds = None
                visual_token_mask = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }
        if self.is_mono:
            forward_kwargs.update({"visual_token_mask": visual_token_mask})

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)
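
# Example usage (a sketch, not part of the original module): serving an
# InternVL2 checkpoint through vLLM's offline API. The model name, image
# path, and prompt format below are assumptions; consult the model card for
# the exact chat template. The "<image>" placeholder is consumed by
# input_pipeline.input_processor above.
#
#     from PIL import Image
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="OpenGVLab/InternVL2-4B", trust_remote_code=True)
#     image = Image.open("example.jpg")
#     outputs = llm.generate(
#         {
#             "prompt": "<image>\nDescribe this image.",
#             "multi_modal_data": {"image": image},
#         },
#         SamplingParams(max_tokens=64),
#     )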