From 583d6af71b68a8b8464dac2d1c60d0cb5dd65738 Mon Sep 17 00:00:00 2001 From: Mick Date: Wed, 5 Mar 2025 14:18:26 +0800 Subject: [PATCH] example: add vlm to token in & out example (#3941) Co-authored-by: zhaochenyang20 --- docs/backend/server_arguments.md | 2 +- examples/runtime/engine/readme.md | 10 ++- .../token_in_token_out_llm.py | 2 +- .../token_in_token_out_vlm.py | 75 +++++++++++++++++++ python/sglang/srt/configs/model_config.py | 2 +- python/sglang/srt/managers/scheduler.py | 13 +--- python/sglang/srt/model_loader/loader.py | 6 +- python/sglang/test/test_utils.py | 5 ++ test/srt/test_skip_tokenizer_init.py | 68 ++++++++++++++--- 9 files changed, 154 insertions(+), 29 deletions(-) rename examples/runtime/engine/{ => token_in_token_out}/token_in_token_out_llm.py (98%) create mode 100644 examples/runtime/engine/token_in_token_out/token_in_token_out_vlm.py diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 0d6d279ad..9777b9b4c 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you ma * `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template). * `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks. * `revision`: Adjust if a specific version of the model should be used. -* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out_llm.py). +* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out/). * `json_model_override_args`: Override model config with the provided JSON. * `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model. diff --git a/examples/runtime/engine/readme.md b/examples/runtime/engine/readme.md index 986e0b12e..82174e082 100644 --- a/examples/runtime/engine/readme.md +++ b/examples/runtime/engine/readme.md @@ -9,15 +9,15 @@ SGLang provides a direct inference engine without the need for an HTTP server. T ## Examples -### 1. [Offline Batch Inference](./offline_batch_inference.py) +### [Offline Batch Inference](./offline_batch_inference.py) In this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine will intelligently schedule the requests to process efficiently and prevent OOM (Out of Memory) errors. -### 2. [Embedding Generation](./embedding.py) +### [Embedding Generation](./embedding.py) In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation. -### 3. [Custom Server](./custom_server.py) +### [Custom Server](./custom_server.py) This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. 
The server supports both non-streaming and streaming endpoints. @@ -43,3 +43,7 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio ``` This will send both non-streaming and streaming requests to the server. + +### [Token-In-Token-Out for RLHF](./token_in_token_out) + +In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output. diff --git a/examples/runtime/engine/token_in_token_out_llm.py b/examples/runtime/engine/token_in_token_out/token_in_token_out_llm.py similarity index 98% rename from examples/runtime/engine/token_in_token_out_llm.py rename to examples/runtime/engine/token_in_token_out/token_in_token_out_llm.py index 8bcd6e3ec..97efc30f9 100644 --- a/examples/runtime/engine/token_in_token_out_llm.py +++ b/examples/runtime/engine/token_in_token_out/token_in_token_out_llm.py @@ -30,7 +30,7 @@ def main(): # Print the outputs. for prompt, output in zip(prompts, outputs): print("===============================") - print(f"Prompt: {prompt}\nGenerated token ids: {output['token_ids']}") + print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}") # The __main__ condition is necessary here because we use "spawn" to create subprocesses diff --git a/examples/runtime/engine/token_in_token_out/token_in_token_out_vlm.py b/examples/runtime/engine/token_in_token_out/token_in_token_out_vlm.py new file mode 100644 index 000000000..3c3717dcb --- /dev/null +++ b/examples/runtime/engine/token_in_token_out/token_in_token_out_vlm.py @@ -0,0 +1,75 @@ +import argparse +import dataclasses +from io import BytesIO +from typing import Tuple + +import requests +from PIL import Image +from transformers import AutoProcessor + +from sglang import Engine +from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.server_args import ServerArgs +from sglang.test.test_utils import DEFAULT_IMAGE_URL + + +def get_input_ids( + server_args: ServerArgs, model_config: ModelConfig +) -> Tuple[list[int], list]: + chat_template = get_chat_template_by_model_path(model_config.model_path) + text = f"{chat_template.image_token}What is in this picture?" 
+    images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
+    image_data = [DEFAULT_IMAGE_URL]
+
+    processor = AutoProcessor.from_pretrained(
+        model_config.model_path, trust_remote_code=server_args.trust_remote_code
+    )
+
+    inputs = processor(
+        text=[text],
+        images=images,
+        return_tensors="pt",
+    )
+
+    return inputs.input_ids[0].tolist(), image_data
+
+
+def token_in_out_example(
+    server_args: ServerArgs,
+):
+    input_ids, image_data = get_input_ids(
+        server_args,
+        ModelConfig(
+            server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_override_args=server_args.json_model_override_args,
+        ),
+    )
+    backend = Engine(**dataclasses.asdict(server_args))
+
+    output = backend.generate(
+        input_ids=input_ids,
+        image_data=image_data,
+        sampling_params={
+            "temperature": 0.8,
+            "max_new_tokens": 32,
+        },
+    )
+
+    print("===============================")
+    print("Output token ids:", output["output_ids"])
+
+    backend.shutdown()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = [
+        "--model-path=Qwen/Qwen2-VL-2B",
+    ]
+    args = parser.parse_args(args=args)
+    server_args = ServerArgs.from_cli_args(args)
+    server_args.skip_tokenizer_init = True
+    token_in_out_example(server_args)
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 8880288f1..64ef15cf7 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -40,7 +40,7 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-        model_override_args: Optional[dict] = None,
+        model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5c181570a..c47405c43 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -42,7 +41,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
-    BatchMultimodalDecodeReq,
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
@@ -104,7 +103,6 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
-    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -1199,7 +1197,6 @@ class Scheduler:
             self.spec_num_total_forward_ct += batch.batch_size()
             self.num_generated_tokens += num_accepted_tokens
         batch.output_ids = next_token_ids
-
         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
         # we can use the correct values in output processing.
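A usage note on the new example above: because `token_in_token_out_vlm.py` forces `skip_tokenizer_init = True`, `Engine.generate` returns raw token IDs in `output["output_ids"]` rather than text, so any detokenization happens on the caller's side. A minimal sketch of that last step, assuming the same `Qwen/Qwen2-VL-2B` checkpoint as the example (the concrete ids below are placeholders, not real output):

```python
from transformers import AutoTokenizer

# Hypothetical follow-up to token_in_token_out_vlm.py: decode the raw ids
# returned by Engine.generate() back into text on the client side.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B")

output_ids = [785, 6802, 4933]  # placeholder ids; use output["output_ids"] here
print(tokenizer.decode(output_ids, skip_special_tokens=True))
```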
@@ -1480,7 +1477,6 @@ class Scheduler:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
-
         self.stream_output(batch.reqs, batch.return_logprob)
 
         self.token_to_kv_pool.free_group_end()
@@ -1580,11 +1576,11 @@ class Scheduler:
                 if req.top_logprobs_num > 0:
                     req.input_top_logprobs_val = [None]
                     req.input_top_logprobs_idx = [None]
-
+                    assert len(req.temp_input_top_logprobs_val) == len(
+                        req.temp_input_top_logprobs_idx
+                    )
                     for val, idx in zip(
-                        req.temp_input_top_logprobs_val,
-                        req.temp_input_top_logprobs_idx,
-                        strict=True,
+                        req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
                     ):
                         req.input_top_logprobs_val.extend(val)
                         req.input_top_logprobs_idx.extend(idx)
@@ -1779,7 +1775,6 @@ class Scheduler:
             if rids:
                 if self.model_config.is_multimodal_gen:
                     raise NotImplementedError()
-
                 self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         rids,
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 9e6b09488..eff4aa5f3 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -11,7 +11,7 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import gguf
 import huggingface_hub
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.device_config import DeviceConfig
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
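A note on the `zip(..., strict=True)` change in the scheduler hunk above: `strict=True` requires Python 3.10+, so the explicit length assert restores the same fail-fast behavior on older interpreters (the compatibility motivation is an inference, not stated in the patch). A self-contained sketch of the equivalence, with hypothetical data:

```python
# A pre-3.10-compatible stand-in for zip(..., strict=True): assert equal
# lengths up front, then zip normally.
vals = [[0.1, 0.2], [0.3]]  # hypothetical per-position top-logprob values
idxs = [[11, 42], [7]]      # hypothetical matching token indices

assert len(vals) == len(idxs)  # zip(strict=True) would raise ValueError here
for val, idx in zip(vals, idxs):
    print(val, idx)
```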
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 05e4fc558..a8bf674a9 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -43,10 +43,15 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Ins DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4" DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct" +DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B" + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf" DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B" +DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" +DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + def is_in_ci(): """Return whether it is in CI runner.""" diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index d714a593c..41787e2c1 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -5,13 +5,18 @@ python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_st import json import unittest +from io import BytesIO import requests -from transformers import AutoTokenizer +from PIL import Image +from transformers import AutoProcessor, AutoTokenizer +from sglang.lang.chat_template import get_chat_template_by_model_path from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( + DEFAULT_IMAGE_URL, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_VLM_MODEL_NAME, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -29,6 +34,7 @@ class TestSkipTokenizerInit(unittest.TestCase): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=["--skip-tokenizer-init", "--stream-output"], ) + cls.eos_token_id = [119690] cls.tokenizer = AutoTokenizer.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False ) @@ -45,9 +51,7 @@ class TestSkipTokenizerInit(unittest.TestCase): top_logprobs_num=0, n=1, ): - input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][ - 0 - ].tolist() + input_ids = self.get_input_ids(prompt_text) response = requests.post( self.base_url + "/generate", @@ -104,7 +108,7 @@ class TestSkipTokenizerInit(unittest.TestCase): def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1): max_new_tokens = 32 - input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is + input_ids = self.get_input_ids("The capital of France is") requests.post(self.base_url + "/flush_cache") response = requests.post( self.base_url + "/generate", @@ -114,7 +118,7 @@ class TestSkipTokenizerInit(unittest.TestCase): "temperature": 0 if n == 1 else 0.5, "max_new_tokens": max_new_tokens, "n": n, - "stop_token_ids": [119690], + "stop_token_ids": self.eos_token_id, }, "stream": False, "return_logprob": return_logprob, @@ -125,6 +129,9 @@ class TestSkipTokenizerInit(unittest.TestCase): ret = response.json() print(json.dumps(ret)) output_ids = ret["output_ids"] + print("output from non-streaming request:") + print(output_ids) + 
print(self.tokenizer.decode(output_ids, skip_special_tokens=True)) requests.post(self.base_url + "/flush_cache") response_stream = requests.post( @@ -135,7 +142,7 @@ class TestSkipTokenizerInit(unittest.TestCase): "temperature": 0 if n == 1 else 0.5, "max_new_tokens": max_new_tokens, "n": n, - "stop_token_ids": [119690], + "stop_token_ids": self.eos_token_id, }, "stream": True, "return_logprob": return_logprob, @@ -143,13 +150,10 @@ class TestSkipTokenizerInit(unittest.TestCase): "logprob_start_len": 0, }, ) - ret = response.json() - output_ids = ret["output_ids"] - print("output from non-streaming request:") - print(output_ids) response_stream_json = [] for line in response_stream.iter_lines(): + print(line) if line.startswith(b"data: ") and line[6:] != b"[DONE]": response_stream_json.append(json.loads(line[6:])) out_stream_ids = [] @@ -157,6 +161,8 @@ class TestSkipTokenizerInit(unittest.TestCase): out_stream_ids += x["output_ids"] print("output from streaming request:") print(out_stream_ids) + print(self.tokenizer.decode(out_stream_ids, skip_special_tokens=True)) + assert output_ids == out_stream_ids def test_simple_decode(self): @@ -175,6 +181,46 @@ class TestSkipTokenizerInit(unittest.TestCase): def test_simple_decode_stream(self): self.run_decode_stream() + def get_input_ids(self, prompt_text) -> list[int]: + input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][ + 0 + ].tolist() + return input_ids + + +class TestSkipTokenizerInitVLM(TestSkipTokenizerInit): + @classmethod + def setUpClass(cls): + cls.image_url = DEFAULT_IMAGE_URL + response = requests.get(cls.image_url) + cls.image = Image.open(BytesIO(response.content)) + cls.model = DEFAULT_SMALL_VLM_MODEL_NAME + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model, use_fast=False) + cls.processor = AutoProcessor.from_pretrained(cls.model, trust_remote_code=True) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--skip-tokenizer-init"], + ) + cls.eos_token_id = [cls.tokenizer.eos_token_id] + + def get_input_ids(self, _prompt_text) -> list[int]: + chat_template = get_chat_template_by_model_path(self.model) + text = f"{chat_template.image_token}What is in this picture?" + inputs = self.processor( + text=[text], + images=[self.image], + return_tensors="pt", + ) + + return inputs.input_ids[0].tolist() + + def test_simple_decode_stream(self): + # TODO mick + pass + if __name__ == "__main__": unittest.main()
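For reference, the token-in/token-out flow these tests exercise can also be driven over plain HTTP against a server started with `--skip-tokenizer-init`. A minimal sketch, where the model name, port, and launch command are assumptions rather than part of this patch:

```python
import requests
from transformers import AutoTokenizer

# Assumed setup (not part of this patch): a server launched beforehand with
#   python -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct \
#       --skip-tokenizer-init
# and listening on localhost:30000.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
input_ids = tokenizer("The capital of France is")["input_ids"]

# With --skip-tokenizer-init, /generate accepts raw ids and returns raw ids.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "input_ids": input_ids,
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    },
)
output_ids = resp.json()["output_ids"]
print(output_ids)
print(tokenizer.decode(output_ids, skip_special_tokens=True))
```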