diff --git a/docs/backend/vlm_query.ipynb b/docs/backend/vlm_query.ipynb index 7aba8dfb8..519811f75 100644 --- a/docs/backend/vlm_query.ipynb +++ b/docs/backend/vlm_query.ipynb @@ -132,7 +132,7 @@ "\n", "mm_item = dict(\n", " modality=\"IMAGE\",\n", - " image_grid_thws=processed_prompt[\"image_grid_thw\"],\n", + " image_grid_thw=processed_prompt[\"image_grid_thw\"],\n", " precomputed_features=precomputed_features,\n", ")\n", "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n", diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index fa2e4d824..dae9c2b75 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -5,7 +5,8 @@ import multiprocessing as mp import os import re from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple, Union +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -16,16 +17,24 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.utils import encode_video, load_audio, load_image +class MultimodalInputFormat(Enum): + """Enum for different multimodal input formats.""" + + RAW_IMAGES = "raw_images" + PRECOMPUTED_FEATURES = "precomputed_features" + PIXEL_VALUES = "pixel_values" + + @dataclasses.dataclass class BaseMultiModalProcessorOutput: # input_text, with each frame of video/image represented with a image_token input_text: str # frames loaded from image and video, in given order - images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None + images: Optional[list[Union[Image.Image, dict]]] = None # audios - audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None + audios: Optional[list[Union[np.ndarray, dict]]] = None def normalize(self): for field_name in ["images", "audios"]: @@ -170,8 +179,6 @@ class BaseMultimodalProcessor(ABC): ): """Static method that can be pickled for multiprocessing""" if isinstance(data, dict): - return MultimodalDataItem.from_dict(data) - if isinstance(data, MultimodalDataItem): return data try: if is_audio: @@ -370,29 +377,180 @@ class BaseMultimodalProcessor(ABC): return list(zip(indices_start.tolist(), indices_end.tolist())) - def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]): - """Returns true if all images are preprocessed, false if all are not, and error otherwise.""" - if not mm_inputs: - return True - ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs) - if ret and not all( - isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs - ): - raise ValueError( - "Unsupported: mixture of multimodal inputs where some but not all are preprocessed." - ) - return ret - @staticmethod def _extract_processor_features( - items: List[Any], attr_name: str + items: List[dict], attr_name: str ) -> Optional[torch.Tensor]: """ Helper function to concat extracted attributes from processor output. 
""" - values = [ - getattr(item, attr_name) - for item in items - if getattr(item, attr_name) is not None - ] - return torch.concat(values) if values else None + values = [value for item in items if (value := item.get(attr_name)) is not None] + return torch.cat(values) if values else None + + # When we assume that all the items have the same attributes + def _extract_processor_features_from_all_attributes( + self, items: List[dict] + ) -> dict: + values = {} + # Verify all items have the same keys + first_keys = set(items[0].keys()) + for item in items[1:]: + if set(item.keys()) != first_keys: + raise ValueError( + f"All items must have the same attributes. " + f"First item has {first_keys}, but found {set(item.keys())}" + ) + + # Process each attribute + for k, v in items[0].items(): + if isinstance(v, list): + values[k] = self._extract_processor_features(items, k) + else: + # Verify all items have the same value for non-list attributes + for item in items[1:]: + if item[k] != v: + raise ValueError( + f"All items must have the same value for attribute {k}. " + f"First item has {v}, but found {item[k]}" + ) + values[k] = v + return values + + def process_and_combine_mm_data( + self, base_output: BaseMultiModalProcessorOutput + ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]: + """ + Process multimodal data and return the combined multimodal item and input_ids. + Handles all three input formats at the same abstraction level. + + Returns: + Tuple of (combined_mm_item, input_ids) + """ + + def tokenize_text(input_text: str) -> torch.Tensor: + """Tokenize input text.""" + return self._processor.tokenizer( + input_text, + return_tensors="pt", + add_special_tokens=True, + ).input_ids.flatten() + + def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat: + """Categorize multimodal inputs and validate consistency.""" + try: + has_image = False + has_pixel_values = False + has_precomputed_features = False + + for mm_input in mm_inputs: + if isinstance(mm_input, Image.Image): + has_image = True + elif isinstance(mm_input, dict): + if mm_input.get("precomputed_features", None) is not None: + has_precomputed_features = True + elif mm_input.get("pixel_values", None) is not None: + has_pixel_values = True + else: + raise ValueError( + f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features" + ) + else: + raise ValueError( + f"Invalid multimodal input: {mm_input}, expected Image.Image or dict" + ) + + # Validate format consistency + format_count = sum( + [has_image, has_pixel_values, has_precomputed_features] + ) + if format_count > 1: + raise ValueError( + "Unsupported: mixture of multimodal input formats. 
" + f"Found formats: image={has_image}, pixel_values={has_pixel_values}, " + f"precomputed_features={has_precomputed_features}" + ) + + if has_image: + return MultimodalInputFormat.RAW_IMAGES + elif has_precomputed_features: + return MultimodalInputFormat.PRECOMPUTED_FEATURES + elif has_pixel_values: + return MultimodalInputFormat.PIXEL_VALUES + else: + raise ValueError("No valid multimodal input format found") + except Exception as e: + raise ValueError(f"Failed to categorize inputs: {e}") + + def process_raw_images( + base_output: BaseMultiModalProcessorOutput, + ) -> Tuple[MultimodalDataItem, torch.Tensor]: + """Process raw Image.Image objects using transformers processor.""" + ret = self.process_mm_data( + input_text=base_output.input_text, + images=base_output.images, + ) + combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE) + + # Copy all fields from processor output except input_ids + for key, value in ret.items(): + if key != "input_ids" and hasattr(combined_mm_item, key): + setattr(combined_mm_item, key, value) + + input_ids = ret["input_ids"].flatten() + return combined_mm_item, input_ids + + def process_precomputed_features( + base_output: BaseMultiModalProcessorOutput, + ) -> Tuple[MultimodalDataItem, torch.Tensor]: + """Process inputs with precomputed features.""" + combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE) + combined_mm_item.precomputed_features = self._extract_processor_features( + base_output.images, "precomputed_features" + ) + input_ids = tokenize_text(base_output.input_text) + return combined_mm_item, input_ids + + def process_pixel_values( + base_output: BaseMultiModalProcessorOutput, + ) -> Tuple[MultimodalDataItem, torch.Tensor]: + """Process inputs with pixel values.""" + values = self._extract_processor_features_from_all_attributes( + base_output.images + ) + combined_mm_item = MultimodalDataItem.from_dict(values) + input_ids = tokenize_text(base_output.input_text) + return combined_mm_item, input_ids + + def finalize_mm_item( + combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor + ) -> MultimodalDataItem: + """Apply common post-processing to the multimodal item.""" + combined_mm_item.image_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.IM_TOKEN_ID, + ) + return combined_mm_item + + # Main logic + mm_inputs = base_output.images + if not mm_inputs: + # Return text-only case + input_ids = tokenize_text(base_output.input_text) + return None, input_ids + + # Categorize input formats + input_format = categorize_mm_inputs(mm_inputs) + + # Process based on format + if input_format == MultimodalInputFormat.RAW_IMAGES: + combined_mm_item, input_ids = process_raw_images(base_output) + elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES: + combined_mm_item, input_ids = process_precomputed_features(base_output) + elif input_format == MultimodalInputFormat.PIXEL_VALUES: + combined_mm_item, input_ids = process_pixel_values(base_output) + else: + raise ValueError(f"Unknown input format: {input_format}") + + # Finalize with common processing + combined_mm_item = finalize_mm_item(combined_mm_item, input_ids) + return combined_mm_item, input_ids diff --git a/python/sglang/srt/managers/multimodal_processors/gemma3.py b/python/sglang/srt/managers/multimodal_processors/gemma3.py index 1f7846ba9..9e28284bb 100644 --- a/python/sglang/srt/managers/multimodal_processors/gemma3.py +++ b/python/sglang/srt/managers/multimodal_processors/gemma3.py @@ -27,6 +27,7 @@ class 
Gemma3SGLangImageProcessor(SGLangBaseProcessor): ) self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index + self.IM_TOKEN_ID = hf_config.image_token_index async def process_mm_data_async( self, @@ -42,49 +43,21 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): if isinstance(image_data, str): image_data = [image_data] - image_token = self.IMAGE_TOKEN - image_token_regex = self.IMAGE_TOKEN_REGEX base_output = self.load_mm_data( prompt=input_text, image_data=image_data, multimodal_tokens=MultimodalSpecialTokens( - image_token=image_token, image_token_regex=image_token_regex + image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX ), max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) - images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images) - ret = self.process_mm_data( - input_text=base_output.input_text, - images=None if images_are_preprocessed else base_output.images, - ) - - items = [] - input_ids = ret["input_ids"].flatten() - image_offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.hf_config.image_token_index, - ) - for i, image in enumerate(base_output.images): - if images_are_preprocessed: - pixel_values = image.pixel_values - precomputed_features = image.precomputed_features - else: - pixel_values = ret["pixel_values"][i] - precomputed_features = None - - item = MultimodalDataItem( - pixel_values=pixel_values, - precomputed_features=precomputed_features, - modality=Modality.IMAGE, - image_offsets=image_offsets[i], - ) - items += [item] + combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output) return { - "mm_items": items, "input_ids": input_ids.tolist(), + "mm_items": [combined_mm_item] if combined_mm_item is not None else [], "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, } diff --git a/python/sglang/srt/managers/multimodal_processors/kimi_vl.py b/python/sglang/srt/managers/multimodal_processors/kimi_vl.py index 86d41189e..7431a2ec1 100644 --- a/python/sglang/srt/managers/multimodal_processors/kimi_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/kimi_vl.py @@ -21,7 +21,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor): super().__init__(hf_config, server_args, _processor) self.IMAGE_TOKEN = "<|media_pad|>" self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+") - self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) + self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) async def process_mm_data_async( self, @@ -46,48 +46,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor): max_req_input_len=max_req_input_len, ) - images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images) - if not images_are_preprocessed: - ret = self.process_mm_data( - input_text=base_output.input_text, - images=base_output.images, - ) - input_ids = ret["input_ids"].flatten() - image_grid_thws = ret["image_grid_hws"] - pixel_values = ret["pixel_values"] - precomputed_features = None - else: - input_ids = self._processor.tokenizer( - base_output.input_text, - return_tensors="pt", - add_special_tokens=True, - ).input_ids.flatten() - - image_grid_thws = self._extract_processor_features( - base_output.images, "image_grid_thws" - ) - precomputed_features = self._extract_processor_features( - base_output.images, "precomputed_features" - ) - pixel_values = self._extract_processor_features( - base_output.images, "pixel_values" - ) - - image_offsets = 
self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.im_token_id, - ) + combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output) return { "input_ids": input_ids.tolist(), - "mm_items": [ - MultimodalDataItem( - pixel_values=pixel_values, - image_grid_thws=image_grid_thws, - precomputed_features=precomputed_features, - modality=Modality.IMAGE, - image_offsets=image_offsets, - ) - ], - "im_token_id": self.im_token_id, + "mm_items": [combined_mm_item] if combined_mm_item is not None else [], + "im_token_id": self.IM_TOKEN_ID, } diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py index 268d3350b..d09b61b29 100644 --- a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py @@ -32,8 +32,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ) self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id - self.image_token_id = hf_config.image_token_id - self.video_token_id = hf_config.video_token_id + self.IM_TOKEN_ID = hf_config.image_token_id + self.VIDEO_TOKEN_ID = hf_config.video_token_id self.vision_start_token_id = hf_config.vision_start_token_id self.vision_end_token_id = hf_config.vision_end_token_id self.NUM_TOKEN_PER_FRAME = 770 @@ -125,72 +125,45 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): async def resize_image_async(image): return resize_image(image) - images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images) - if base_output.images and not images_are_preprocessed: + # Qwen-specific: resize images if they are raw Image objects + if base_output.images and isinstance(base_output.images[0], Image.Image): resize_tasks = [resize_image_async(image) for image in base_output.images] base_output.images = await asyncio.gather(*resize_tasks) - ret = self.process_mm_data( - input_text=base_output.input_text, - images=None if images_are_preprocessed else base_output.images, - ) - input_ids = ret["input_ids"].flatten().tolist() - image_offsets = self.get_mm_items_offset( - input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id - ) - image_grid_thw = None video_grid_thw = None # TODO - items = [] - if base_output.images: - if images_are_preprocessed: - image_grid_thw = self._extract_processor_features( - base_output.images, "image_grid_thws" - ) - precomputed_features = self._extract_processor_features( - base_output.images, "precomputed_features" - ) - pixel_values = self._extract_processor_features( - base_output.images, "pixel_values" - ) - else: - image_grid_thw = ret["image_grid_thw"] - pixel_values = ret["pixel_values"] - precomputed_features = None - items += [ - MultimodalDataItem( - pixel_values=pixel_values, - image_grid_thws=image_grid_thw, - video_grid_thws=video_grid_thw, - precomputed_features=precomputed_features, - image_offsets=image_offsets, - modality=Modality.IMAGE, - ) - ] + combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output) + + if combined_mm_item is None: + # Note(Xinyuan): This is the case where image loading fails. 
+ return None + + video_grid_thw = None # TODO + second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None) mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, - image_token_id=self.image_token_id, - video_token_id=self.video_token_id, + image_token_id=self.IM_TOKEN_ID, + video_token_id=self.VIDEO_TOKEN_ID, vision_start_token_id=self.vision_start_token_id, model_type=self.hf_config.model_type, tokens_per_second=getattr( self.hf_config.vision_config, "tokens_per_second", None ), - input_ids=torch.tensor(input_ids).unsqueeze(0), - image_grid_thw=image_grid_thw, + input_ids=input_ids.unsqueeze(0), + image_grid_thw=combined_mm_item.image_grid_thw, video_grid_thw=video_grid_thw, - second_per_grid_ts=ret.get("second_per_grid_ts", None), + second_per_grid_ts=second_per_grid_ts, ) mrope_positions = mrope_positions.squeeze(1) return { - "input_ids": input_ids, - "mm_items": items, + "input_ids": input_ids.tolist(), + "mm_items": [combined_mm_item], "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, - "im_token_id": self.image_token_id, - "video_token_id": self.video_token_id, + "im_token_id": self.IM_TOKEN_ID, + "video_token_id": self.VIDEO_TOKEN_ID, "mrope_positions": mrope_positions, "mrope_position_delta": mrope_position_delta, } diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9192648e8..d299edca0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -188,7 +188,7 @@ class MultimodalDataItem: # the real data, pixel_values or audio_features # data: Union[List[torch.Tensor], List[np.ndarray]] pixel_values: Union[torch.Tensor, np.ndarray] = None - image_grid_thws: Union[torch.Tensor, np.ndarray] = None + image_grid_thw: Union[torch.Tensor, np.ndarray] = None video_grid_thws: Union[torch.Tensor, np.ndarray] = None image_emb_mask: Optional[torch.Tensor] = None @@ -198,6 +198,9 @@ class MultimodalDataItem: # [num_images, (n, w, h)] tgt_size: Tuple[int, int] = None + # kimi-vl related + image_grid_hws: Optional[List[torch.Tensor]] = None + audio_features: Union[torch.Tensor, np.ndarray] = None audio_feature_lens: Optional[List[torch.Tensor]] = None audio_offsets: Optional[List[Tuple[int, int]]] = None diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 45b3c4572..e61715031 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -286,14 +286,26 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): all_pixel_values = flatten_nested_list([item.pixel_values for item in items]) vision_outputs_list = [] - for pixel_value in all_pixel_values: - # Add batch dimension for single image processing - pixel_value_batch = pixel_value.unsqueeze(0) - pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device) - pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype()) + for pixel_values_batch in all_pixel_values: + # Normalize input shape to [batch_size, channels, height, width] + if pixel_values_batch.dim() == 5: + pixel_values_batch = pixel_values_batch.squeeze(0) + elif pixel_values_batch.dim() == 3: + pixel_values_batch = pixel_values_batch.unsqueeze(0) + elif pixel_values_batch.dim() != 4: + raise ValueError( + f"Unexpected pixel_values shape: {pixel_values_batch.shape}" + ) - vision_output = self.vision_tower(pixel_values=pixel_value_batch) - 
vision_outputs_list.append(vision_output) + # Process each image in the batch + batch_size = pixel_values_batch.shape[0] + for i in range(batch_size): + pixel_value = pixel_values_batch[i : i + 1] # Keep batch dimension as 1 + pixel_value = pixel_value.to( + device=self.vision_tower.device, dtype=self.language_model.dtype() + ) + vision_output = self.vision_tower(pixel_values=pixel_value) + vision_outputs_list.append(vision_output) # Concatenate all vision outputs vision_outputs = torch.cat(vision_outputs_list, dim=0) diff --git a/python/sglang/srt/models/kimi_vl.py b/python/sglang/srt/models/kimi_vl.py index 0efbf2724..9311df485 100644 --- a/python/sglang/srt/models/kimi_vl.py +++ b/python/sglang/srt/models/kimi_vl.py @@ -144,10 +144,10 @@ class KimiVLForConditionalGeneration(nn.Module): .type(self.vision_tower.dtype) .to(self.vision_tower.device) ) - image_grid_thws = torch.concat( - [item.image_grid_thws for item in items], dim=0 - ).to(self.vision_tower.device) - image_features = self.vision_tower(pixel_values, image_grid_thws) + image_grid_hws = torch.cat([item.image_grid_hws for item in items], dim=0).to( + self.vision_tower.device + ) + image_features = self.vision_tower(pixel_values, image_grid_hws) assert isinstance(image_features, list) # lengths = [x.shape[0] for x in image_features] res = self.multi_modal_projector(torch.cat(image_features)) # .split(lengths) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 4fbac6fd4..2de6789df 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -503,10 +503,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( self.visual.dtype ) - image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0) + image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() - assert image_grid_thws.dim() == 2, image_grid_thws.dim() - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws) + assert image_grid_thw.dim() == 2, image_grid_thw.dim() + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) return image_embeds def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index f653401d8..4c9026c1a 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -490,10 +490,10 @@ class Qwen2VLForConditionalGeneration(nn.Module): pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( self.visual.dtype ) - image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0) + image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() - assert image_grid_thws.dim() == 2, image_grid_thws.dim() - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws) + assert image_grid_thw.dim() == 2, image_grid_thw.dim() + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) return image_embeds def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py index 51f2f5592..f45696a3c 100644 --- a/test/srt/test_vlm_input_format.py +++ b/test/srt/test_vlm_input_format.py @@ -156,7 +156,7 @@ class 
TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestC def _pixel_values_image_data(self, processor_output): return dict( modality="IMAGE", - image_grid_thws=processor_output["image_grid_thw"], + image_grid_thw=processor_output["image_grid_thw"], pixel_values=processor_output["pixel_values"], ) @@ -207,8 +207,8 @@ class TestKimiVLImageUnderstandsImage( def _pixel_values_image_data(self, processor_output): return dict( modality="IMAGE", - image_grid_thws=processor_output["image_grid_hws"], pixel_values=processor_output["pixel_values"], + image_grid_hws=processor_output["image_grid_hws"], )
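
Usage sketch (not part of the patch, for reference only): the two dict-based image_data formats exercised by the notebook and test hunks above, assuming llm, input_ids, processed_prompt, and precomputed_features are prepared as in docs/backend/vlm_query.ipynb, and that processed_prompt carries the processor's pixel_values. Note the renamed singular key image_grid_thw.

    # Format 1: pre-tokenized pixel values (dispatched to process_pixel_values).
    pixel_item = dict(
        modality="IMAGE",
        image_grid_thw=processed_prompt["image_grid_thw"],
        pixel_values=processed_prompt["pixel_values"],
    )

    # Format 2: precomputed vision features (dispatched to
    # process_precomputed_features, bypassing the pixel path).
    feature_item = dict(
        modality="IMAGE",
        image_grid_thw=processed_prompt["image_grid_thw"],
        precomputed_features=precomputed_features,
    )

    out = llm.generate(input_ids=input_ids, image_data=[feature_item])

Raw PIL images remain the third accepted format (process_raw_images); mixing formats within one request raises a ValueError in categorize_mm_inputs.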