Add examples for server token-in-token-out (#4103)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
@@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you ma
|
|||||||
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
|
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template).
|
||||||
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
|
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
|
||||||
* `revision`: Adjust if a specific version of the model should be used.
|
* `revision`: Adjust if a specific version of the model should be used.
|
||||||
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/token_in_token_out/).
|
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
|
||||||
* `json_model_override_args`: Override model config with the provided JSON.
|
* `json_model_override_args`: Override model config with the provided JSON.
|
||||||
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
|
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,6 @@ curl -X POST http://localhost:8000/generate_stream -H "Content-Type: applicatio
|
|||||||
|
|
||||||
This will send both non-streaming and streaming requests to the server.
|
This will send both non-streaming and streaming requests to the server.
|
||||||
|
|
||||||
### [Token-In-Token-Out for RLHF](./token_in_token_out)
|
### [Token-In-Token-Out for RLHF](../token_in_token_out)
|
||||||
|
|
||||||
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
|
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
|
||||||
|
|||||||
@@ -29,8 +29,12 @@ def main():
|
|||||||
outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)
|
outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)
|
||||||
# Print the outputs.
|
# Print the outputs.
|
||||||
for prompt, output in zip(prompts, outputs):
|
for prompt, output in zip(prompts, outputs):
|
||||||
|
decode_output = tokenizer.decode(output["output_ids"])
|
||||||
print("===============================")
|
print("===============================")
|
||||||
print(f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}")
|
print(
|
||||||
|
f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
python token_in_token_out_llm_server.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.test.test_utils import is_in_ci
|
||||||
|
from sglang.utils import print_highlight, terminate_process, wait_for_server
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
from docs.backend.patch import launch_server_cmd
|
||||||
|
else:
|
||||||
|
from sglang.utils import launch_server_cmd
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Launch an SGLang server with tokenizer init skipped, send pre-tokenized
    prompts to `/generate`, and print the generated token ids and decoded text.

    Demonstrates the token-in-token-out flow over HTTP: the client tokenizes
    locally, posts `input_ids`, and receives `output_ids` back.
    """
    # Launch the server with --skip-tokenizer-init so it accepts/returns raw token ids.
    server_process, port = launch_server_cmd(
        f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
    )
    wait_for_server(f"http://localhost:{port}")

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Tokenize inputs client-side; the server never sees the raw text.
    tokenizer = get_tokenizer(MODEL_PATH)
    token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]

    json_data = {
        "input_ids": token_ids_list,
        "sampling_params": sampling_params,
    }

    response = requests.post(
        f"http://localhost:{port}/generate",
        json=json_data,
    )
    # Fail fast with the HTTP status instead of a confusing JSON decode error.
    response.raise_for_status()

    # Batched input -> the server returns a list of outputs, one per prompt.
    outputs = response.json()
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        decode_output = tokenizer.decode(output["output_ids"])
        print(
            f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
        )
        print()

    terminate_process(server_process)


# The __main__ guard is required because the server is spawned as a subprocess.
if __name__ == "__main__":
    main()
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
python token_in_token_out_vlm_server.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.test.test_utils import DEFAULT_IMAGE_URL, is_in_ci
|
||||||
|
from sglang.utils import print_highlight, terminate_process, wait_for_server
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
from docs.backend.patch import launch_server_cmd
|
||||||
|
else:
|
||||||
|
from sglang.utils import launch_server_cmd
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_PATH = "Qwen/Qwen2-VL-2B"
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_ids() -> Tuple[list[int], list]:
    """Build the pre-tokenized multimodal prompt for the VLM example.

    Returns:
        A tuple of (input token ids including the image placeholder token,
        image_data list of image URLs to send alongside the ids).
    """
    # Use the model's chat template to get the correct image placeholder token.
    chat_template = get_chat_template_by_model_path(MODEL_PATH)
    text = f"{chat_template.image_token}What is in this picture?"
    # The processor needs the actual image to produce correct token ids,
    # but the server is sent the URL (image_data), not the pixels.
    images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
    image_data = [DEFAULT_IMAGE_URL]

    processor = AutoProcessor.from_pretrained(MODEL_PATH)

    inputs = processor(
        text=[text],
        images=images,
        return_tensors="pt",
    )

    # Single prompt -> first (and only) row of the batch, as a plain list.
    return inputs.input_ids[0].tolist(), image_data
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Launch an SGLang VLM server with tokenizer init skipped, send a
    pre-tokenized multimodal prompt to `/generate`, and print the output ids.
    """
    # Launch the server with --skip-tokenizer-init so it accepts/returns raw token ids.
    server_process, port = launch_server_cmd(
        f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
    )
    wait_for_server(f"http://localhost:{port}")

    input_ids, image_data = get_input_ids()

    sampling_params = {
        "temperature": 0.8,
        "max_new_tokens": 32,
    }

    json_data = {
        "input_ids": input_ids,
        "image_data": image_data,
        "sampling_params": sampling_params,
    }

    response = requests.post(
        f"http://localhost:{port}/generate",
        json=json_data,
    )
    # Fail fast with the HTTP status instead of a confusing JSON decode error.
    response.raise_for_status()

    # Single (non-batched) input -> a single output dict.
    output = response.json()
    print("===============================")
    # Fixed: was `print(f"Output token ids: ", output["output_ids"])` — an
    # f-string with no placeholders plus a stray positional argument.
    print(f"Output token ids: {output['output_ids']}")

    terminate_process(server_process)


# The __main__ guard is required because the server is spawned as a subprocess.
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user