example: add vlm to token in & out example (#3941)

Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
Mick
2025-03-05 14:18:26 +08:00
committed by GitHub
parent e074d84e5b
commit 583d6af71b
9 changed files with 154 additions and 29 deletions

View File

@@ -9,15 +9,15 @@ SGLang provides a direct inference engine without the need for an HTTP server. T
## Examples
### 1. [Offline Batch Inference](./offline_batch_inference.py)
### [Offline Batch Inference](./offline_batch_inference.py)
In this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine will intelligently schedule the requests to process efficiently and prevent OOM (Out of Memory) errors.
### 2. [Embedding Generation](./embedding.py)
### [Embedding Generation](./embedding.py)
In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.
### 3. [Custom Server](./custom_server.py)
### [Custom Server](./custom_server.py)
This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.
@@ -43,3 +43,7 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio
```
This will send both non-streaming and streaming requests to the server.
### [Token-In-Token-Out for RLHF](./token_in_token_out)
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.

View File

@@ -30,7 +30,7 @@ def main():
# Print the outputs.
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated token ids: {output['token_ids']}")
print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
# The __main__ condition is necessary here because we use "spawn" to create subprocesses

View File

@@ -0,0 +1,75 @@
import argparse
import dataclasses
from io import BytesIO
from typing import Tuple
import requests
from PIL import Image
from transformers import AutoProcessor
from sglang import Engine
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs
from sglang.test.test_utils import DEFAULT_IMAGE_URL
def get_input_ids(
server_args: ServerArgs, model_config: ModelConfig
) -> Tuple[list[int], list]:
chat_template = get_chat_template_by_model_path(model_config.model_path)
text = f"{chat_template.image_token}What is in this picture?"
images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
image_data = [DEFAULT_IMAGE_URL]
processor = AutoProcessor.from_pretrained(
model_config.model_path, trust_remote_code=server_args.trust_remote_code
)
inputs = processor(
text=[text],
images=images,
return_tensors="pt",
)
return inputs.input_ids[0].tolist(), image_data
def token_in_out_example(
server_args: ServerArgs,
):
input_ids, image_data = get_input_ids(
server_args,
ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
model_override_args=server_args.json_model_override_args,
),
)
backend = Engine(**dataclasses.asdict(server_args))
output = backend.generate(
input_ids=input_ids,
image_data=image_data,
sampling_params={
"temperature": 0.8,
"max_new_tokens": 32,
},
)
print("===============================")
print(f"Output token ids: ", output["output_ids"])
backend.shutdown()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = [
"--model-path=Qwen/Qwen2-VL-2B",
]
args = parser.parse_args(args=args)
server_args = ServerArgs.from_cli_args(args)
server_args.skip_tokenizer_init = True
token_in_out_example(server_args)