From 6fcd6d7d6dec7aea858d7441effd8a04b6d05474 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 27 Oct 2024 14:02:34 -0700 Subject: [PATCH] Support token ids in `engine.generate` (#1820) --- examples/runtime/engine/input_ids.py | 39 ++++++++++++++++++++++++++++ python/sglang/srt/server.py | 14 +++++++--- test/srt/test_srt_engine.py | 23 ++++++++++++++++ 3 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 examples/runtime/engine/input_ids.py diff --git a/examples/runtime/engine/input_ids.py b/examples/runtime/engine/input_ids.py new file mode 100644 index 000000000..89de6f63a --- /dev/null +++ b/examples/runtime/engine/input_ids.py @@ -0,0 +1,39 @@ +""" +This example demonstrates how to provide tokenized ids as input instead of text prompt +""" + +import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer + +MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct" + +def main(): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + # Create a sampling params object. + sampling_params = {"temperature": 0.8, "top_p": 0.95} + + # Tokenize inputs + tokenizer = get_tokenizer(MODEL_PATH) + token_ids_list = [tokenizer.encode(prompt) for prompt in prompts] + + # Create an LLM. + # You can also specify `skip_tokenizer_init=True`, but it requires explicit detokenization at the end + llm = sgl.Engine(model_path=MODEL_PATH) + + outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params) + # Print the outputs. 
+ for prompt, output in zip(prompts, outputs): + print("===============================") + print(f"Prompt: {prompt}\nGenerated Text: {output['text']}") + + +# The __main__ condition is necessary here because we use "spawn" to create subprocesses +# Spawn starts a fresh program every time; if there is no __main__ guard, it will run into an infinite loop that keeps spawning processes from sgl.Engine +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 8cf86a091..64f6c6f55 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -742,18 +742,20 @@ class Engine: def generate( self, - prompt: Union[str, List[str]], + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, sampling_params: Optional[Dict] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, lora_path: Optional[List[Optional[str]]] = None, stream: bool = False, ): - # TODO (ByronHsu): refactor to reduce the duplicated code - obj = GenerateReqInput( text=prompt, + input_ids=input_ids, sampling_params=sampling_params, return_logprob=return_logprob, logprob_start_len=logprob_start_len, @@ -791,8 +793,11 @@ class Engine: async def async_generate( self, - prompt: Union[str, List[str]], + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, sampling_params: Optional[Dict] = None, + # The token ids for text; one can either specify text or input_ids. 
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, @@ -801,6 +806,7 @@ class Engine: ): obj = GenerateReqInput( text=prompt, + input_ids=input_ids, sampling_params=sampling_params, return_logprob=return_logprob, logprob_start_len=logprob_start_len, diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 8743e0ef9..38781b0e2 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -9,6 +9,7 @@ import unittest from types import SimpleNamespace import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.test.few_shot_gsm8k_engine import run_eval from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST @@ -106,6 +107,28 @@ class TestSRTEngine(unittest.TestCase): metrics = run_eval(args) assert metrics["accuracy"] > 0.7 + def test_5_prompt_input_ids_consistency(self): + prompt = "The capital of UK is" + + + model_path = DEFAULT_MODEL_NAME_FOR_TEST + engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error") + sampling_params = {"temperature": 0, "max_new_tokens": 8} + out1 = engine.generate(prompt, sampling_params)["text"] + + tokenizer = get_tokenizer(model_path) + token_ids = tokenizer.encode(prompt) + out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)["text"] + + engine.shutdown() + + print("==== Answer 1 ====") + print(out1) + + print("==== Answer 2 ====") + print(out2) + assert out1 == out2, f"{out1} != {out2}" + if __name__ == "__main__": unittest.main()