diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index afc46c50f..b92c6ecdb 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -151,10 +151,6 @@ class Engine: The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. """ - modalities_list = [] - if image_data is not None: - modalities_list.append("image") - obj = GenerateReqInput( text=prompt, input_ids=input_ids, @@ -165,7 +161,6 @@ class Engine: top_logprobs_num=top_logprobs_num, token_ids_logprob=token_ids_logprob, lora_path=lora_path, - modalities=modalities_list, custom_logit_processor=custom_logit_processor, return_hidden_states=return_hidden_states, stream=stream, diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index 462b3beef..c976f24f7 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -139,8 +139,6 @@ class BaseMultimodalProcessor(ABC): else: multimodal_tokens.image_token = multimodal_tokens.image_token - assert isinstance(prompt, str) - if isinstance(prompt, list) and return_text: assert len(prompt) and isinstance(prompt[0], int) prompt = self._processor.tokenizer.decode(prompt) @@ -204,7 +202,16 @@ class BaseMultimodalProcessor(ABC): continue image_sizes += frames[0].size * len(frames) - hashes += [hash(image_file)] * len(frames) + + # Generate a hashable value for the image file + if isinstance(image_file, Image.Image): + # For PIL.Image objects, use the ID as a hashable value + hash_value = hash(id(image_file)) + else: + # For other types (strings, etc.), use the regular hash + hash_value = hash(image_file) + + hashes += [hash_value] * len(frames) images += frames image_index += 1 if frames_to_process != 0: diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py index f1575956c..9d0d6d843 100644 --- a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py @@ -5,7 +5,7 @@ from typing import List, Union import torch from PIL import Image -from sglang.srt.managers.multimodal_processor import ( +from sglang.srt.managers.multimodal_processors.base_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) from sglang.srt.managers.multimodal_processors.base_processor import ( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 9dc222291..498bc58cc 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -566,10 +566,14 @@ def encode_video(video_path, frame_count_limit=None): return frames -def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]: +def load_image( + image_file: Union[Image.Image, str, bytes] +) -> tuple[Image.Image, tuple[int, int]]: image = image_size = None - - if isinstance(image_file, bytes): + if isinstance(image_file, Image.Image): + image = image_file + image_size = (image.width, image.height) + elif isinstance(image_file, bytes): image = Image.open(BytesIO(image_file)) elif image_file.startswith("http://") or image_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))