example: add vlm to token in & out example (#3941)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
@@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you ma
|
||||
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
|
||||
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
|
||||
* `revision`: Adjust if a specific version of the model should be used.
|
||||
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out_llm.py).
|
||||
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out/).
|
||||
* `json_model_override_args`: Override model config with the provided JSON.
|
||||
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
|
||||
|
||||
|
||||
@@ -9,15 +9,15 @@ SGLang provides a direct inference engine without the need for an HTTP server. T
|
||||
|
||||
## Examples
|
||||
|
||||
### 1. [Offline Batch Inference](./offline_batch_inference.py)
|
||||
### [Offline Batch Inference](./offline_batch_inference.py)
|
||||
|
||||
In this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine will intelligently schedule the requests to process efficiently and prevent OOM (Out of Memory) errors.
|
||||
|
||||
### 2. [Embedding Generation](./embedding.py)
|
||||
### [Embedding Generation](./embedding.py)
|
||||
|
||||
In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.
|
||||
|
||||
### 3. [Custom Server](./custom_server.py)
|
||||
### [Custom Server](./custom_server.py)
|
||||
|
||||
This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.
|
||||
|
||||
@@ -43,3 +43,7 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio
|
||||
```
|
||||
|
||||
This will send both non-streaming and streaming requests to the server.
|
||||
|
||||
### [Token-In-Token-Out for RLHF](./token_in_token_out)
|
||||
|
||||
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
|
||||
|
||||
@@ -30,7 +30,7 @@ def main():
|
||||
# Print the outputs.
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
print("===============================")
|
||||
print(f"Prompt: {prompt}\nGenerated token ids: {output['token_ids']}")
|
||||
print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
@@ -0,0 +1,75 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import DEFAULT_IMAGE_URL
|
||||
|
||||
|
||||
def get_input_ids(
|
||||
server_args: ServerArgs, model_config: ModelConfig
|
||||
) -> Tuple[list[int], list]:
|
||||
chat_template = get_chat_template_by_model_path(model_config.model_path)
|
||||
text = f"{chat_template.image_token}What is in this picture?"
|
||||
images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
|
||||
image_data = [DEFAULT_IMAGE_URL]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_config.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
return inputs.input_ids[0].tolist(), image_data
|
||||
|
||||
|
||||
def token_in_out_example(
|
||||
server_args: ServerArgs,
|
||||
):
|
||||
input_ids, image_data = get_input_ids(
|
||||
server_args,
|
||||
ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
),
|
||||
)
|
||||
backend = Engine(**dataclasses.asdict(server_args))
|
||||
|
||||
output = backend.generate(
|
||||
input_ids=input_ids,
|
||||
image_data=image_data,
|
||||
sampling_params={
|
||||
"temperature": 0.8,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
)
|
||||
|
||||
print("===============================")
|
||||
print(f"Output token ids: ", output["output_ids"])
|
||||
|
||||
backend.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
args = [
|
||||
"--model-path=Qwen/Qwen2-VL-2B",
|
||||
]
|
||||
args = parser.parse_args(args=args)
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
server_args.skip_tokenizer_init = True
|
||||
token_in_out_example(server_args)
|
||||
@@ -40,7 +40,7 @@ class ModelConfig:
|
||||
trust_remote_code: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
context_length: Optional[int] = None,
|
||||
model_override_args: Optional[dict] = None,
|
||||
model_override_args: Optional[str] = None,
|
||||
is_embedding: Optional[bool] = None,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
|
||||
@@ -42,7 +42,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalDecodeReq,
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
FlushCacheReq,
|
||||
@@ -104,7 +103,6 @@ from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
get_bool_env_var,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
pyspy_dump_schedulers,
|
||||
set_gpu_proc_affinity,
|
||||
set_random_seed,
|
||||
@@ -1199,7 +1197,6 @@ class Scheduler:
|
||||
self.spec_num_total_forward_ct += batch.batch_size()
|
||||
self.num_generated_tokens += num_accepted_tokens
|
||||
batch.output_ids = next_token_ids
|
||||
|
||||
# These 2 values are needed for processing the output, but the values can be
|
||||
# modified by overlap schedule. So we have to copy them here so that
|
||||
# we can use the correct values in output processing.
|
||||
@@ -1480,7 +1477,6 @@ class Scheduler:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
self.stream_output(batch.reqs, batch.return_logprob)
|
||||
|
||||
self.token_to_kv_pool.free_group_end()
|
||||
@@ -1580,11 +1576,11 @@ class Scheduler:
|
||||
if req.top_logprobs_num > 0:
|
||||
req.input_top_logprobs_val = [None]
|
||||
req.input_top_logprobs_idx = [None]
|
||||
|
||||
assert len(req.temp_input_token_ids_logprobs_val) == len(
|
||||
req.temp_input_token_ids_logprobs_idx
|
||||
)
|
||||
for val, idx in zip(
|
||||
req.temp_input_top_logprobs_val,
|
||||
req.temp_input_top_logprobs_idx,
|
||||
strict=True,
|
||||
req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
|
||||
):
|
||||
req.input_top_logprobs_val.extend(val)
|
||||
req.input_top_logprobs_idx.extend(idx)
|
||||
@@ -1779,7 +1775,6 @@ class Scheduler:
|
||||
if rids:
|
||||
if self.model_config.is_multimodal_gen:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchTokenIDOut(
|
||||
rids,
|
||||
|
||||
@@ -11,7 +11,7 @@ import math
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, cast
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
||||
|
||||
import gguf
|
||||
import huggingface_hub
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from torch import nn
|
||||
from transformers import AutoModelForCausalLM, PretrainedConfig
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
Returns the path to the downloaded model, or None if the model is not
|
||||
downloaded from ModelScope."""
|
||||
if "SGLANG_USE_MODELSCOPE" in os.environ:
|
||||
if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
# pylint: disable=C.
|
||||
|
||||
@@ -43,10 +43,15 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Ins
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"
|
||||
|
||||
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
|
||||
|
||||
DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
|
||||
|
||||
|
||||
def is_in_ci():
|
||||
"""Return whether it is in CI runner."""
|
||||
|
||||
@@ -5,13 +5,18 @@ python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_st
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_IMAGE_URL,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_VLM_MODEL_NAME,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
@@ -29,6 +34,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--skip-tokenizer-init", "--stream-output"],
|
||||
)
|
||||
cls.eos_token_id = [119690]
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
|
||||
)
|
||||
@@ -45,9 +51,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
top_logprobs_num=0,
|
||||
n=1,
|
||||
):
|
||||
input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
|
||||
0
|
||||
].tolist()
|
||||
input_ids = self.get_input_ids(prompt_text)
|
||||
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
@@ -104,7 +108,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
|
||||
def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
max_new_tokens = 32
|
||||
input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is
|
||||
input_ids = self.get_input_ids("The capital of France is")
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
@@ -114,7 +118,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
"stop_token_ids": self.eos_token_id,
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
@@ -125,6 +129,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
ret = response.json()
|
||||
print(json.dumps(ret))
|
||||
output_ids = ret["output_ids"]
|
||||
print("output from non-streaming request:")
|
||||
print(output_ids)
|
||||
print(self.tokenizer.decode(output_ids, skip_special_tokens=True))
|
||||
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
response_stream = requests.post(
|
||||
@@ -135,7 +142,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
"stop_token_ids": self.eos_token_id,
|
||||
},
|
||||
"stream": True,
|
||||
"return_logprob": return_logprob,
|
||||
@@ -143,13 +150,10 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
ret = response.json()
|
||||
output_ids = ret["output_ids"]
|
||||
print("output from non-streaming request:")
|
||||
print(output_ids)
|
||||
|
||||
response_stream_json = []
|
||||
for line in response_stream.iter_lines():
|
||||
print(line)
|
||||
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
||||
response_stream_json.append(json.loads(line[6:]))
|
||||
out_stream_ids = []
|
||||
@@ -157,6 +161,8 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
out_stream_ids += x["output_ids"]
|
||||
print("output from streaming request:")
|
||||
print(out_stream_ids)
|
||||
print(self.tokenizer.decode(out_stream_ids, skip_special_tokens=True))
|
||||
|
||||
assert output_ids == out_stream_ids
|
||||
|
||||
def test_simple_decode(self):
|
||||
@@ -175,6 +181,46 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
def test_simple_decode_stream(self):
|
||||
self.run_decode_stream()
|
||||
|
||||
def get_input_ids(self, prompt_text) -> list[int]:
|
||||
input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
|
||||
0
|
||||
].tolist()
|
||||
return input_ids
|
||||
|
||||
|
||||
class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.image_url = DEFAULT_IMAGE_URL
|
||||
response = requests.get(cls.image_url)
|
||||
cls.image = Image.open(BytesIO(response.content))
|
||||
cls.model = DEFAULT_SMALL_VLM_MODEL_NAME
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model, use_fast=False)
|
||||
cls.processor = AutoProcessor.from_pretrained(cls.model, trust_remote_code=True)
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--skip-tokenizer-init"],
|
||||
)
|
||||
cls.eos_token_id = [cls.tokenizer.eos_token_id]
|
||||
|
||||
def get_input_ids(self, _prompt_text) -> list[int]:
|
||||
chat_template = get_chat_template_by_model_path(self.model)
|
||||
text = f"{chat_template.image_token}What is in this picture?"
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=[self.image],
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
return inputs.input_ids[0].tolist()
|
||||
|
||||
def test_simple_decode_stream(self):
|
||||
# TODO mick
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user