# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py # -------------------------------------------------------- # SkyworkR1V # Copyright (c) 2025 Skywork # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence from typing import Literal, Optional, TypedDict, TypeVar, Union import torch import torch.nn as nn import torchvision.transforms as T from PIL import Image from transformers import BatchEncoding, PretrainedConfig, TensorType from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) IMG_START = '' IMG_END = '' IMG_CONTEXT = '' IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) class SkyworkR1VImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values_flat: torch.Tensor """ Shape: `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` """ num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" class SkyworkR1VImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: Union[torch.Tensor, list[torch.Tensor]] """ A tensor of shape `(num_images, total_image_feature_size, hidden_size)` or a list of tensors of shape `(total_image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, SkyworkR1VImageEmbeddingInputs] # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD return T.Compose([ T.Lambda(lambda img: convert_image_mode(img, 'RGB')), T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ def find_closest_aspect_ratio( aspect_ratio: float, target_ratios: list[tuple[int, int]], *, width: int, height: int, image_size: int, ) -> tuple[int, int]: best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio return best_ratio def resolve_skyworkr1v_min_max_num( *, min_dynamic_patch: int, max_dynamic_patch: int, dynamic_image_size: bool, use_thumbnail: bool, ) -> tuple[int, int]: min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if use_thumbnail and max_dynamic_patch != 1: max_dynamic_patch += 1 return min_dynamic_patch, max_dynamic_patch def get_skyworkr1v_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: target_ratios = {(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if min_num <= i * j <= max_num} return sorted(target_ratios, key=lambda x: x[0] * x[1]) def calculate_skyworkr1v_targets( *, orig_width: int, orig_height: int, target_ratios: list[tuple[int, int]], image_size: int, use_thumbnail: bool, ) -> tuple[int, int, int]: aspect_ratio = orig_width / orig_height # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, width=orig_width, height=orig_height, image_size=image_size, ) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # add thumbnail image if num_blocks != 1 if use_thumbnail and blocks != 1: blocks += 1 return blocks, target_width, target_height def dynamic_preprocess_skyworkr1v( image: Image.Image, *, target_ratios: list[tuple[int, int]], image_size: int, use_thumbnail: bool, ) -> list[Image.Image]: orig_width, orig_height = image.size # calculate the number of blocks without thumbnail blocks, target_width, target_height = calculate_skyworkr1v_targets( orig_width=orig_width, orig_height=orig_height, target_ratios=target_ratios, image_size=image_size, use_thumbnail=False, ) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B def image_to_pixel_values_skyworkr1v( image: Image.Image, *, input_size: int, min_num: int, max_num: int, use_thumbnail: bool, ) -> torch.Tensor: target_ratios = get_skyworkr1v_target_ratios(min_num, max_num) transform = build_transform(input_size=input_size) images = dynamic_preprocess_skyworkr1v( image, target_ratios=target_ratios, image_size=input_size, use_thumbnail=use_thumbnail, ) pixel_values = torch.stack([transform(image) for image in images]) return pixel_values class BaseSkyworkR1VProcessor(ABC): """ This model doesn't define its own HF processor, so we implement our own one here. The code to insert image tokens is based on: https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252 """ def __init__( self, config: PretrainedConfig, tokenizer: AnyTokenizer, *, min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, ) -> None: super().__init__() self.config = config self.tokenizer = tokenizer image_size: int = config.vision_config.image_size patch_size: int = config.vision_config.patch_size if min_dynamic_patch is None: min_dynamic_patch = config.min_dynamic_patch assert isinstance(min_dynamic_patch, int) if max_dynamic_patch is None: max_dynamic_patch = config.max_dynamic_patch assert isinstance(max_dynamic_patch, int) if dynamic_image_size is None: dynamic_image_size = config.dynamic_image_size assert isinstance(dynamic_image_size, bool) self.num_image_token = int( (image_size // patch_size)**2 * (config.downsample_ratio**2)) self.image_size = image_size self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch self.dynamic_image_size = dynamic_image_size self.use_thumbnail: bool = config.use_thumbnail @property @abstractmethod def image_token_id(self) -> int: raise NotImplementedError @abstractmethod def get_image_repl( self, feature_size: int, num_patches: Optional[int], ) -> PromptUpdateDetails[str]: raise NotImplementedError def resolve_min_max_num( self, *, min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch) max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch) dynamic_image_size = (self.dynamic_image_size if dynamic_image_size is None else dynamic_image_size) use_thumbnail = (self.use_thumbnail if use_thumbnail is None else use_thumbnail) return resolve_skyworkr1v_min_max_num( min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, use_thumbnail=use_thumbnail, ) def resolve_target_ratios( self, *, min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> list[tuple[int, int]]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, use_thumbnail=use_thumbnail, ) return get_skyworkr1v_target_ratios(min_num, max_num) def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: target_ratios = self.resolve_target_ratios( use_thumbnail=False, # Applied in calculate_targets ) num_patches, _, _ = calculate_skyworkr1v_targets( orig_width=image_width, orig_height=image_height, image_size=self.image_size, target_ratios=target_ratios, use_thumbnail=self.use_thumbnail, ) return num_patches * self.num_image_token def _images_to_pixel_values_lst( self, images: list[Image.Image], min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, ) -> list[torch.Tensor]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, use_thumbnail=False, # Applied in image_to_pixel_values ) return [ image_to_pixel_values_skyworkr1v( image, input_size=self.image_size, min_num=min_num, max_num=max_num, use_thumbnail=self.use_thumbnail, ) for image in images ] def __call__( self, text: Optional[Union[str, list[str]]] = None, images: Optional[Union[Image.Image, list[Image.Image]]] = None, min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, ) -> Mapping[str, NestedTensors]: if text is None: text = [] if not isinstance(text, list): text = [text] if images is None: images = [] if not isinstance(images, list): images = [images] if len(images) == 0: image_inputs = {} else: pixel_values_lst = self._images_to_pixel_values_lst( images, min_dynamic_patch=min_dynamic_patch, max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) image_inputs: dict[str, NestedTensors] = { "pixel_values_flat": torch.cat(pixel_values_lst), "image_num_patches": torch.tensor([len(item) for item in pixel_values_lst]), } for pixel_values in pixel_values_lst: num_patches = pixel_values.shape[0] feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) text = [t.replace('', image_repl.full, 1) for t in text] text_inputs = self.tokenizer(text) return { **BatchEncoding(text_inputs, tensor_type=return_tensors), **image_inputs, } class SkyworkR1VProcessor(BaseSkyworkR1VProcessor): @property def image_token_id(self) -> int: return self.tokenizer.get_vocab()[IMG_CONTEXT] def get_image_repl( self, feature_size: int, num_patches: Optional[int], ) -> PromptUpdateDetails[str]: repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): @abstractmethod def get_hf_processor( self, *, min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, **kwargs: object, ) -> BaseSkyworkR1VProcessor: raise NotImplementedError def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def get_num_image_tokens( self, *, image_width: int, image_height: int, processor: Optional[BaseSkyworkR1VProcessor], ) -> int: if processor is None: processor = self.get_hf_processor() return processor.get_num_image_tokens( image_width=image_width, image_height=image_height, ) def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() base_size = processor.image_size target_ratios = processor.resolve_target_ratios() largest_feature_size, largest_feature_pinpoint = 0, None for wr, hr in target_ratios: width, height = base_size * wr, base_size * hr feat_size = self.get_num_image_tokens( image_width=width, image_height=height, processor=processor, ) if feat_size > largest_feature_size: largest_feature_size = feat_size largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") return largest_feature_pinpoint _I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo) class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) return "" * num_images def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], ) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, ) hf_processor = self.info.get_hf_processor(**mm_kwargs) image_token_id = hf_processor.image_token_id # Since there may be extra tokens in the feature placeholders, # we need to pass the image token ID to the model to select the # tokens to merge from the vision encoder outputs processed_outputs["image_token_id"] = torch.tensor(image_token_id) return processed_outputs def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) num_images = len(image_num_patches) return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( "image", image_num_patches), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), ) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if "image_num_patches" in out_mm_kwargs: image_num_patches = out_mm_kwargs["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() elif "image_embeds" in out_mm_kwargs: # TODO: Use image size information in dictionary embedding inputs # to compute num_patches (similar to Qwen2-VL) image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) else: image_num_patches = [] def get_replacement_skyworkr1v(item_idx: int): images = mm_items.get_items( "image", (ImageEmbeddingItems, ImageProcessorItems)) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) else: image_size = images.get_image_size(item_idx) feature_size = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, processor=hf_processor, ) num_patches = image_num_patches[item_idx] if num_patches is not None: assert isinstance(num_patches, int) return hf_processor.get_image_repl(feature_size, num_patches) return [ PromptReplacement( modality="image", target="", replacement=get_replacement_skyworkr1v, ) ] class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo): def get_hf_processor( self, *, min_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, **kwargs: object, ) -> SkyworkR1VProcessor: if min_dynamic_patch is not None: kwargs["min_dynamic_patch"] = min_dynamic_patch if max_dynamic_patch is not None: kwargs["max_dynamic_patch"] = max_dynamic_patch if dynamic_image_size is not None: kwargs["dynamic_image_size"] = dynamic_image_size return self.ctx.init_processor( SkyworkR1VProcessor, config=self.get_hf_config(), tokenizer=self.get_tokenizer(), **kwargs, ) @MULTIMODAL_REGISTRY.register_processor( SkyworkR1VMultiModalProcessor, info=SkyworkR1VProcessingInfo, dummy_inputs=SkyworkR1VDummyInputsBuilder) class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config self._patch_quant_config(config, quant_config) image_size = config.force_image_size or config.vision_config.image_size patch_size = config.vision_config.patch_size self.patch_size = patch_size self.num_image_token = int( (image_size // patch_size)**2 * (config.downsample_ratio**2)) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.llm_arch_name = config.text_config.architectures[0] self.is_mono = self.llm_arch_name == 'SkyworkLM2VEForCausalLM' self.vision_model = self._init_vision_model( config, quant_config=quant_config, is_mono=self.is_mono, prefix=maybe_prefix(prefix, "vision_model"), ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) self.mlp1 = self._init_mlp1(config) self.img_context_token_id = None self.visual_token_mask = None self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) def _patch_quant_config(self, config: PretrainedConfig, quant_config: QuantizationConfig): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config llm_quant_config = getattr(text_config, "quantization_config", None) if (not quant_config.modules_to_not_convert) and \ (llm_quant_config is not None): quant_config.modules_to_not_convert.append("vision_model") def _init_vision_model( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], *, is_mono: bool, prefix: str, ): if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: num_hidden_layers = config.vision_config.num_hidden_layers \ + vision_feature_layer + 1 else: num_hidden_layers = vision_feature_layer + 1 return InternVisionModel( config.vision_config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, prefix=prefix, ) else: return InternVisionPatchModel(config.vision_config) def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size return nn.Sequential( nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), ReplicatedLinear(vit_hidden_size * int(1 / self.downsample_ratio)**2, llm_hidden_size, return_bias=False), nn.GELU(), ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False), ) def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))) if self.ps_version == 'v1': pass else: x = x.permute(0, 2, 1, 3).contiguous() return x def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.vision_model(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] h = w = int(vit_embeds.shape[1]**0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) def _validate_shape(d: torch.Tensor): actual_dims = tuple(d.shape) if actual_dims != expected_dims: expected_expr = str(expected_dims) raise ValueError( "The expected shape of pixel values per image per batch " f" per patch is {expected_expr}. " f"You supplied {tuple(d.shape)}.") for d in data: _validate_shape(d) return data def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values_flat is None and image_embeds is None: return None if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") return SkyworkR1VImageEmbeddingInputs( type="image_embeds", data=flatten_bn(image_embeds), ) image_token_id = kwargs["image_token_id"] assert isinstance(image_token_id, torch.Tensor) self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: if not isinstance(pixel_values_flat, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values_flat)}") if not isinstance(image_num_patches, (torch.Tensor, list)): raise ValueError("Incorrect type of image_num_patches. " f"Got type: {type(image_num_patches)}") pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) return SkyworkR1VImagePixelInputs( type="pixel_values", pixel_values_flat=self._validate_pixel_values( pixel_values_flat), num_patches=image_num_patches, ) raise AssertionError("This line should be unreachable.") def _process_image_input( self, image_input: SkyworkR1VImageInputs, ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None image_embeds = self.extract_feature(image_input["pixel_values_flat"]) num_patches = image_input["num_patches"] # Only one image in the current batch if len(num_patches) == 1: return image_embeds.view( -1, self.config.text_config.hidden_size).unsqueeze(0) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. feature_size = image_embeds.shape[1] image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] return image_embeds.split(image_feature_sizes) def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if self.is_mono: self.visual_token_mask = ( input_ids == self.img_context_token_id).reshape(-1, 1) else: self.visual_token_mask = None def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None return self._process_image_input(image_input) def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: assert self.img_context_token_id is not None self._set_visual_token_mask(input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, self.img_context_token_id, ) return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: input_ids = None inputs_embeds = None # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None forward_kwargs = { "input_ids": input_ids, "positions": positions, "intermediate_tensors": intermediate_tensors, "inputs_embeds": inputs_embeds, } # Only required if the model is mono-architecture if self.visual_token_mask is not None: forward_kwargs.update( {"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [ "action_embed", "temporal_embed", "track_embed", "track_embed_decoder", "box_token", "cg_criterion", "cg_model", "loc_encoder", "loc_decoder", "sam", "temporal_token", "track_token" ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights)