# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `Terratorch` models.

The model consumes only raw multimodal tensors (registered under the
``"image"`` modality). Text tokens are not used: every request carries a
single dummy token id, and inference is delegated to terratorch's
``InferenceRunner``.
"""

from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
                             InputDefinition, InputTypeEnum)
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict, PlaceholderRange)
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
                         SupportsMultiModal)
from .interfaces_base import default_pooling_type


def _terratorch_field_names(pretrained_cfg: dict) -> set[str]:
    """Return the names of all inputs declared in the model config.

    Args:
        pretrained_cfg: The ``pretrained_cfg`` section of the HF config; its
            ``"input"`` entry is parsed into a terratorch ``InputDefinition``.

    Returns:
        The set of keys of ``InputDefinition.data``.
    """
    input_definition = InputDefinition(**pretrained_cfg["input"])
    return set(input_definition.data.keys())


def _terratorch_field_factory(
    pretrained_cfg: dict
) -> Callable[
    [Mapping[str, torch.Tensor]],
        Mapping[str, MultiModalFieldConfig],
]:
    """Build a ``fields_factory`` describing the Terratorch model inputs.

    Only inputs whose declared type is ``InputTypeEnum.tensor`` are exposed.
    Each one becomes a ``MultiModalFieldConfig.shared`` field of the
    ``"image"`` modality with ``batch_size=1``.

    Args:
        pretrained_cfg: The ``pretrained_cfg`` section of the HF config.

    Returns:
        A callable that maps HF inputs to their field configs. The result
        depends only on ``pretrained_cfg``; the ``hf_inputs`` argument is not
        inspected.
    """
    # Parse the input definition once here instead of on every call of the
    # returned factory.
    input_definition = InputDefinition(**pretrained_cfg["input"])
    tensor_field_names = [
        field_name
        for field_name, field_def in input_definition.data.items()
        if field_def.type == InputTypeEnum.tensor
    ]

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            field_name: MultiModalFieldConfig.shared(batch_size=1,
                                                     modality="image")
            for field_name in tensor_field_names
        }

    return _terratorch_field_config


class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info for Terratorch models (``image`` modality only)."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # NOTE(review): ``None`` presumably means "no explicit per-prompt
        # limit" under vLLM's convention — confirm against BaseProcessingInfo.
        return {"image": None}


class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder backed by terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        super().__init__(info)
        # The generator is configured from the ``pretrained_cfg`` section of
        # the HF config.
        self.dummy_data_generator = DummyDataGenerator(
            self.info.get_hf_config().to_dict()["pretrained_cfg"])

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty prompt; the model uses no text input."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy multimodal data for profiling.

        ``seq_len`` and ``mm_counts`` are ignored. Dummy data is generated
        based on the 'input' section defined in the HF configuration file.
        """
        return self.dummy_data_generator.get_dummy_mm_data()


class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Data parser that accepts a dict of named tensors as ``image`` data."""

    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
        # Stored before the base constructor runs so it is available to any
        # parsing hooks. NOTE(review): confirm whether the base __init__
        # actually needs it.
        self._pretrained_cfg = pretrained_cfg
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse ``image`` data.

        A dict is wrapped as ``DictEmbeddingItems``. Its required fields are
        the inputs declared in ``pretrained_cfg``, and its field configs come
        from ``_terratorch_field_factory``. Any other data falls back to the
        base-class parser.
        """
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(self._pretrained_cfg),
            fields_factory=_terratorch_field_factory(self._pretrained_cfg),
        )


class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
    """Processor that forwards raw tensors without an HF processor."""

    def __init__(
            self,
            info: TerratorchProcessingInfo,
            dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
            *,
            cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
        # Set before super().__init__, presumably because the base
        # constructor may call ``_get_data_parser()``, which reads it —
        # TODO confirm.
        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser that understands Terratorch tensor dicts."""
        return TerratorchMultiModalDataParser(
            pretrained_cfg=self.pretrained_cfg)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe the tensor inputs declared in ``pretrained_cfg``."""
        return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates; the prompt is never modified."""
        return []

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Process a request without invoking an HF processor.

        ``mm_data`` may either contain an ``"image"`` entry or be the image
        data itself; in the latter case it is wrapped as
        ``{"image": mm_data}``. The image data is passed through unchanged
        into a ``BatchFeature``. The returned inputs always carry the single
        token id ``[1]`` and one zero-length ``"image"`` placeholder at
        offset 0.
        """
        if "image" in mm_data:
            image_data = mm_data["image"]
        else:
            image_data = mm_data
            mm_data = {"image": mm_data}

        mm_items = self._to_mm_items(mm_data)
        tokenization_kwargs = tokenization_kwargs or {}
        mm_hashes = self._hash_mm_items(mm_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_uuids=mm_uuids)
        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        mm_processed_data = BatchFeature(image_data)

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            self._get_mm_fields_config(mm_processed_data,
                                       hf_processor_mm_kwargs),
        )

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            # A single dummy token: the model does not consume text tokens.
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """Attention-free pooling model that delegates to a terratorch runner."""

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return ``None`` for image modalities.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return None

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the terratorch ``InferenceRunner`` and an encode pooler.

        ``prefix`` is accepted for interface compatibility and is unused.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)})

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Return a zero-width placeholder embedding of shape ``(n, 0)``."""
        # We do not really use any input tokens and therefore no embeddings
        # to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the inference runner on the multimodal keyword inputs.

        ``input_ids``, ``positions``, ``intermediate_tensors`` and
        ``inputs_embeds`` are ignored; only ``kwargs`` are forwarded.
        """
        model_output = self.inference_runner.forward(**kwargs)
        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        Two checkpoint layouts are handled:

        * An entry named ``"state_dict"`` whose value is a mapping. Its keys
          are prefixed with ``inference_runner.``, entries containing
          ``pos_embed`` are skipped, and ``_timm_module.`` is stripped.
          Names matching a module buffer are loaded directly, because those
          buffers are not loadable through ``AutoWeightsLoader``. Everything
          else goes to ``AutoWeightsLoader``. Iteration stops after this
          entry, so any later entries are ignored.
        * Plain tensor entries, which are prefixed with
          ``inference_runner.model.`` and passed to ``AutoWeightsLoader``.

        Mapping-valued entries with any other key are skipped.
        """
        params_list: list[tuple[str, torch.Tensor]] = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers: set[str] = set()

        for key, value in weights:
            if isinstance(value, (dict, OrderedDict)):
                if key != "state_dict":
                    continue

                for name, weight in value.items():
                    name = f"inference_runner.{name}"
                    # NOTE(review): presumably pos_embed is rebuilt by the
                    # model rather than loaded — confirm.
                    if "pos_embed" in name:
                        continue
                    name = name.replace("_timm_module.", "")

                    # This model requires a couple of buffers to be loaded
                    # that are not loadable with the AutoWeightsLoader.
                    if name in model_buffers:
                        buffer = model_buffers[name]
                        weight_loader = getattr(buffer, "weight_loader",
                                                default_weight_loader)
                        weight_loader(buffer, weight)
                        loaded_buffers.add(name)
                    else:
                        params_list.append((name, weight))
                break

            elif isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))

        # Load the remaining model parameters.
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights | loaded_buffers