Add examples for server token-in-token-out (#4103)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
@@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you ma
|
||||
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
|
||||
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
|
||||
* `revision`: Adjust if a specific version of the model should be used.
|
||||
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out/).
|
||||
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
|
||||
* `json_model_override_args`: Override model config with the provided JSON.
|
||||
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
|
||||
|
||||
|
||||
@@ -44,6 +44,6 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio
|
||||
|
||||
This will send both non-streaming and streaming requests to the server.
|
||||
|
||||
### [Token-In-Token-Out for RLHF](./token_in_token_out)
|
||||
### [Token-In-Token-Out for RLHF](../token_in_token_out)
|
||||
|
||||
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
|
||||
|
||||
@@ -29,8 +29,12 @@ def main():
|
||||
outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)
|
||||
# Print the outputs.
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
decode_output = tokenizer.decode(output["output_ids"])
|
||||
print("===============================")
|
||||
print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
|
||||
print(
|
||||
f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python token_in_token_out_llm_server.py
|
||||
|
||||
"""
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.test_utils import is_in_ci
|
||||
from sglang.utils import print_highlight, terminate_process, wait_for_server
|
||||
|
||||
if is_in_ci():
|
||||
from docs.backend.patch import launch_server_cmd
|
||||
else:
|
||||
from sglang.utils import launch_server_cmd
|
||||
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def main():
    """Launch a token-in-token-out SGLang server and query it with pre-tokenized prompts.

    The server is started with ``--skip-tokenizer-init`` so the ``/generate``
    endpoint accepts raw token ids and returns raw token ids; all tokenization
    and detokenization happens on the client side.
    """
    # Start the server and wait until it is ready to accept requests.
    server_process, port = launch_server_cmd(
        f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
    )
    wait_for_server(f"http://localhost:{port}")

    # Sample prompts to run through the model.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Sampling parameters forwarded verbatim to the server.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Tokenize on the client; the server only ever sees token ids.
    tokenizer = get_tokenizer(MODEL_PATH)
    token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]

    response = requests.post(
        f"http://localhost:{port}/generate",
        json={
            "input_ids": token_ids_list,
            "sampling_params": sampling_params,
        },
    )

    # One output entry per input prompt; decode the ids locally for display.
    for prompt, output in zip(prompts, response.json()):
        print("===============================")
        decode_output = tokenizer.decode(output["output_ids"])
        print(
            f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
        )
        print()

    terminate_process(server_process)
|
||||
|
||||
|
||||
# Run the example only when executed as a script, not on import —
# importing this module must not launch a server.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python token_in_token_out_vlm_server.py
|
||||
|
||||
"""
|
||||
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.test_utils import DEFAULT_IMAGE_URL, is_in_ci
|
||||
from sglang.utils import print_highlight, terminate_process, wait_for_server
|
||||
|
||||
if is_in_ci():
|
||||
from docs.backend.patch import launch_server_cmd
|
||||
else:
|
||||
from sglang.utils import launch_server_cmd
|
||||
|
||||
|
||||
MODEL_PATH = "Qwen/Qwen2-VL-2B"
|
||||
|
||||
|
||||
def get_input_ids() -> Tuple[list[int], list]:
    """Build the token ids and image data for one visual question-answering request.

    Returns:
        A tuple ``(input_ids, image_data)`` where ``input_ids`` is the flat list
        of prompt token ids (including the image placeholder token) and
        ``image_data`` is the image reference list passed verbatim to the server.
    """
    # The chat template supplies the model-specific image placeholder token.
    chat_template = get_chat_template_by_model_path(MODEL_PATH)
    prompt = f"{chat_template.image_token}What is in this picture?"

    # Fetch the sample image so the processor can expand the image tokens;
    # the server itself receives only the URL in image_data.
    images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
    image_data = [DEFAULT_IMAGE_URL]

    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    encoded = processor(text=[prompt], images=images, return_tensors="pt")

    return encoded.input_ids[0].tolist(), image_data
|
||||
|
||||
|
||||
def main():
    """Launch a token-in-token-out SGLang server and run one multimodal request.

    The server is started with ``--skip-tokenizer-init`` so the ``/generate``
    endpoint accepts raw token ids (plus image data) and returns raw token ids.
    """
    # Start the server and wait until it is ready to accept requests.
    server_process, port = launch_server_cmd(
        f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
    )
    wait_for_server(f"http://localhost:{port}")

    # Client-side tokenization: token ids plus the image reference list.
    input_ids, image_data = get_input_ids()

    sampling_params = {
        "temperature": 0.8,
        "max_new_tokens": 32,
    }

    json_data = {
        "input_ids": input_ids,
        "image_data": image_data,
        "sampling_params": sampling_params,
    }

    response = requests.post(
        f"http://localhost:{port}/generate",
        json=json_data,
    )

    output = response.json()
    print("===============================")
    # Fix: the original `print(f"Output token ids: ", output["output_ids"])`
    # used an f-string with no placeholders (ruff F541) plus a second
    # positional argument, which printed a spurious double space. Format the
    # ids into a single string instead.
    print(f"Output token ids: {output['output_ids']}")

    terminate_process(server_process)
|
||||
|
||||
|
||||
# Run the example only when executed as a script, not on import —
# importing this module must not launch a server.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user