From fb4c9c3a30acac894d2432806efd75a297bc04aa Mon Sep 17 00:00:00 2001 From: Shenggui Li Date: Sat, 15 Feb 2025 05:27:29 +0800 Subject: [PATCH] [fix] added support for vlm in offline inference (#3548) --- .../engine/offline_batch_inference_vlm.py | 67 +++++++++++++++++++ python/sglang/srt/entrypoints/engine.py | 13 ++++ 2 files changed, 80 insertions(+) create mode 100644 examples/runtime/engine/offline_batch_inference_vlm.py diff --git a/examples/runtime/engine/offline_batch_inference_vlm.py b/examples/runtime/engine/offline_batch_inference_vlm.py new file mode 100644 index 000000000..808d0fce9 --- /dev/null +++ b/examples/runtime/engine/offline_batch_inference_vlm.py @@ -0,0 +1,67 @@ +""" +Usage: +python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template=qwen2-vl +""" + +import argparse +import dataclasses + +from transformers import AutoProcessor + +import sglang as sgl +from sglang.srt.openai_api.adapter import v1_chat_generate_request +from sglang.srt.openai_api.protocol import ChatCompletionRequest +from sglang.srt.server_args import ServerArgs + + +def main( + server_args: ServerArgs, +): + # Create an LLM. + vlm = sgl.Engine(**dataclasses.asdict(server_args)) + + # prepare prompts. 
+ messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true", + }, + }, + ], + } + ] + chat_request = ChatCompletionRequest( + messages=messages, + model=server_args.model_path, + temperature=0.8, + top_p=0.95, + ) + gen_request, _ = v1_chat_generate_request( + [chat_request], + vlm.tokenizer_manager, + ) + + outputs = vlm.generate( + input_ids=gen_request.input_ids, + image_data=gen_request.image_data, + sampling_params=gen_request.sampling_params, + ) + + print("===============================") + print(f"Prompt: {messages[0]['content'][0]['text']}") + print(f"Generated text: {outputs['text']}") + + + # The __main__ condition is necessary here because we use "spawn" to create subprocesses + # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine + if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + main(server_args) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index b0e780706..93bd184c6 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -115,6 +115,9 @@ class Engine: sampling_params: Optional[Union[List[Dict], Dict]] = None, # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be a file name, a url, or base64 encoded string. + # See also python/sglang/srt/utils.py:load_image. 
+ image_data: Optional[Union[List[str], str]] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -126,14 +129,20 @@ class Engine: The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. """ + modalities_list = [] + if image_data is not None: + modalities_list.append("image") + obj = GenerateReqInput( text=prompt, input_ids=input_ids, sampling_params=sampling_params, + image_data=image_data, return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, lora_path=lora_path, + modalities=modalities_list, custom_logit_processor=custom_logit_processor, stream=stream, ) @@ -162,6 +171,9 @@ class Engine: sampling_params: Optional[Union[List[Dict], Dict]] = None, # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None, + # The image input. It can be a file name, a url, or base64 encoded string. + # See also python/sglang/srt/utils.py:load_image. + image_data: Optional[Union[List[str], str]] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -177,6 +189,7 @@ class Engine: text=prompt, input_ids=input_ids, sampling_params=sampling_params, + image_data=image_data, return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num,