Add examples for server token-in-token-out (#4103)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
@@ -44,6 +44,6 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio
|
||||
|
||||
This will send both non-streaming and streaming requests to the server.
|
||||
|
||||
### [Token-In-Token-Out for RLHF](./token_in_token_out)
|
||||
### [Token-In-Token-Out for RLHF](../token_in_token_out)
|
||||
|
||||
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
"""
|
||||
This example demonstrates how to provide tokenized ids to LLM as input instead of text prompt, i.e. a token-in-token-out workflow.
|
||||
"""
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def main():
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
||||
|
||||
# Tokenize inputs
|
||||
tokenizer = get_tokenizer(MODEL_PATH)
|
||||
token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]
|
||||
|
||||
# Create an LLM.
|
||||
llm = sgl.Engine(model_path=MODEL_PATH, skip_tokenizer_init=True)
|
||||
|
||||
outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)
|
||||
# Print the outputs.
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
print("===============================")
|
||||
print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,75 +0,0 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import DEFAULT_IMAGE_URL
|
||||
|
||||
|
||||
def get_input_ids(
|
||||
server_args: ServerArgs, model_config: ModelConfig
|
||||
) -> Tuple[list[int], list]:
|
||||
chat_template = get_chat_template_by_model_path(model_config.model_path)
|
||||
text = f"{chat_template.image_token}What is in this picture?"
|
||||
images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
|
||||
image_data = [DEFAULT_IMAGE_URL]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_config.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
return inputs.input_ids[0].tolist(), image_data
|
||||
|
||||
|
||||
def token_in_out_example(
|
||||
server_args: ServerArgs,
|
||||
):
|
||||
input_ids, image_data = get_input_ids(
|
||||
server_args,
|
||||
ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
),
|
||||
)
|
||||
backend = Engine(**dataclasses.asdict(server_args))
|
||||
|
||||
output = backend.generate(
|
||||
input_ids=input_ids,
|
||||
image_data=image_data,
|
||||
sampling_params={
|
||||
"temperature": 0.8,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
)
|
||||
|
||||
print("===============================")
|
||||
print(f"Output token ids: ", output["output_ids"])
|
||||
|
||||
backend.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
args = [
|
||||
"--model-path=Qwen/Qwen2-VL-2B",
|
||||
]
|
||||
args = parser.parse_args(args=args)
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
server_args.skip_tokenizer_init = True
|
||||
token_in_out_example(server_args)
|
||||
Reference in New Issue
Block a user