From 35bdb48557d6b55e1bfadbadd1084cb23c56f7f4 Mon Sep 17 00:00:00 2001
From: Shi Shuai <126407087+shuaills@users.noreply.github.com>
Date: Sun, 29 Dec 2024 20:28:27 +0000
Subject: [PATCH] [Feature] Get Token IDs with Engine.generate() (#2636)
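
Add a `--return-token-ids` server argument (off by default). When it is
set, each dict returned by `Engine.generate()` additionally carries the
prompt token ids under "input_ids" and the generated token ids under
"output_ids", at the cost of some extra overhead.

A minimal usage sketch, based on the test added below (the model path is
illustrative; any supported model works):

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        return_token_ids=True,
    )
    # generate() accepts a list of prompts and returns one dict per prompt.
    out = llm.generate(["The capital of France is"], {"temperature": 0.8})[0]
    print(out["text"])        # decoded completion
    print(out["input_ids"])   # token ids of the prompt
    print(out["output_ids"])  # token ids of the completion
    llm.shutdown()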
Co-authored-by: Chayenne
---
 .../srt/managers/detokenizer_manager.py       |  2 +
 python/sglang/srt/managers/io_struct.py       |  8 ++-
 python/sglang/srt/managers/scheduler.py       | 10 +++-
 .../sglang/srt/managers/tokenizer_manager.py  |  7 +++
 python/sglang/srt/server_args.py              |  7 +++
 test/srt/run_suite.py                         |  1 +
 test/srt/test_engine_token_ids.py             | 59 +++++++++++++++++++
 7 files changed, 92 insertions(+), 2 deletions(-)
 create mode 100644 test/srt/test_engine_token_ids.py

diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index b4bc1e7a4..fd77d338e 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -181,6 +181,8 @@ class DetokenizerManager:
                     finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
                     prompt_tokens=recv_obj.prompt_tokens,
+                    origin_input_ids=recv_obj.origin_input_ids,
+                    output_ids=recv_obj.output_ids,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 16db89a0a..5fdaef188 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -323,7 +323,9 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when `--skip-tokenizer-init`
+    # Only used when `--return-token-ids` is set
+    origin_input_ids: Optional[List[int]]
+    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -354,6 +356,10 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
 
+    # The token ids
+    origin_input_ids: Optional[List[int]]
+    output_ids: Optional[List[int]]
+
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1f8207edc..3abaa1a6c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1218,6 +1218,7 @@ class Scheduler:
             decode_ids_list = []
             read_offsets = []
             output_ids = []
+            origin_input_ids = []
             skip_special_tokens = []
             spaces_between_special_tokens = []
 
@@ -1266,8 +1267,14 @@ class Scheduler:
                 decode_ids, read_offset = req.init_incremental_detokenize()
                 decode_ids_list.append(decode_ids)
                 read_offsets.append(read_offset)
-                if self.skip_tokenizer_init:
+                if self.skip_tokenizer_init or self.server_args.return_token_ids:
                     output_ids.append(req.output_ids)
+                else:
+                    output_ids = None
+                if self.server_args.return_token_ids:
+                    origin_input_ids.append(req.origin_input_ids)
+                else:
+                    origin_input_ids = None
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
@@ -1299,6 +1306,7 @@ class Scheduler:
                     decoded_texts,
                     decode_ids_list,
                     read_offsets,
+                    origin_input_ids,
                     output_ids,
                     skip_special_tokens,
                     spaces_between_special_tokens,
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 1c81f5e50..e12d9cdb4 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -663,6 +663,13 @@ class TokenizerManager:
                     "text": recv_obj.output_strs[i],
                     "meta_info": meta_info,
                 }
+                if self.server_args.return_token_ids:
+                    out_dict.update(
+                        {
+                            "input_ids": recv_obj.origin_input_ids[i],
+                            "output_ids": recv_obj.output_ids[i],
+                        }
+                    )
             elif isinstance(recv_obj, BatchTokenIDOut):
                 out_dict = {
                     "token_ids": recv_obj.output_ids[i],
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 23beb3eb8..f7177c2d9 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -54,6 +54,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    return_token_ids: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -280,6 +281,12 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request",
         )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output. This may introduce additional overhead.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index b48ee7b23..02fe8032e 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -44,6 +44,7 @@ suites = {
         "test_vision_chunked_prefill.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
+        "test_engine_token_ids.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
diff --git a/test/srt/test_engine_token_ids.py b/test/srt/test_engine_token_ids.py
new file mode 100644
index 000000000..de0bc6665
--- /dev/null
+++ b/test/srt/test_engine_token_ids.py
@@ -0,0 +1,59 @@
+import unittest
+
+from transformers import AutoTokenizer
+
+import sglang as sgl
+
+
+class TestEngineTokenIds(unittest.TestCase):
+    def test_token_ids_in_generate(self):
+        llm = sgl.Engine(
+            model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", return_token_ids=True
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        )
+
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        sampling_params = {"temperature": 0.8, "top_p": 0.95}
+        outputs = llm.generate(prompts, sampling_params)
+
+        # The Hugging Face tokenizer prepends a start (BOS) token when encoding,
+        # while SGLang's output_ids contain only the generated token ids,
+        # so we drop the start token from the HF encoding before comparing.
+        for prompt, output in zip(prompts, outputs):
+            hf_input_ids = tokenizer.encode(prompt)
+            self.assertEqual(
+                output["input_ids"],
+                hf_input_ids,
+                f"Input token IDs mismatch for: {prompt}",
+            )
+
+            hf_output_ids = tokenizer.encode(output["text"])[1:]  # remove start token
+            self.assertEqual(
+                output["output_ids"],
+                hf_output_ids,
+                f"Output token IDs mismatch for: {output['text']}",
+            )
+
+            self.assertEqual(
+                len(output["input_ids"]),
+                output["meta_info"]["prompt_tokens"],
+                "Prompt token count mismatch",
+            )
+            self.assertEqual(
+                len(output["output_ids"]),
+                output["meta_info"]["completion_tokens"],
+                "Completion token count mismatch",
+            )
+
+        llm.shutdown()
+
+
+if __name__ == "__main__":
+    unittest.main()