[Feature] Get Token IDs with Engine.generate() (#2636)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -181,6 +181,8 @@ class DetokenizerManager:
|
|||||||
finished_reasons=recv_obj.finished_reasons,
|
finished_reasons=recv_obj.finished_reasons,
|
||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
prompt_tokens=recv_obj.prompt_tokens,
|
prompt_tokens=recv_obj.prompt_tokens,
|
||||||
|
origin_input_ids=recv_obj.origin_input_ids,
|
||||||
|
output_ids=recv_obj.output_ids,
|
||||||
completion_tokens=recv_obj.completion_tokens,
|
completion_tokens=recv_obj.completion_tokens,
|
||||||
cached_tokens=recv_obj.cached_tokens,
|
cached_tokens=recv_obj.cached_tokens,
|
||||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||||
|
|||||||
@@ -323,7 +323,9 @@ class BatchTokenIDOut:
|
|||||||
decoded_texts: List[str]
|
decoded_texts: List[str]
|
||||||
decode_ids: List[int]
|
decode_ids: List[int]
|
||||||
read_offsets: List[int]
|
read_offsets: List[int]
|
||||||
# Only used when `--skip-tokenizer-init`
|
# Only used when `--return-token-ids` is set
|
||||||
|
origin_input_ids: Optional[List[int]]
|
||||||
|
# Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
|
||||||
output_ids: Optional[List[int]]
|
output_ids: Optional[List[int]]
|
||||||
# Detokenization configs
|
# Detokenization configs
|
||||||
skip_special_tokens: List[bool]
|
skip_special_tokens: List[bool]
|
||||||
@@ -354,6 +356,10 @@ class BatchStrOut:
|
|||||||
# The output decoded strings
|
# The output decoded strings
|
||||||
output_strs: List[str]
|
output_strs: List[str]
|
||||||
|
|
||||||
|
# The token ids
|
||||||
|
origin_input_ids: Optional[List[int]]
|
||||||
|
output_ids: Optional[List[int]]
|
||||||
|
|
||||||
# Token counts
|
# Token counts
|
||||||
prompt_tokens: List[int]
|
prompt_tokens: List[int]
|
||||||
completion_tokens: List[int]
|
completion_tokens: List[int]
|
||||||
|
|||||||
@@ -1218,6 +1218,7 @@ class Scheduler:
|
|||||||
decode_ids_list = []
|
decode_ids_list = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
output_ids = []
|
output_ids = []
|
||||||
|
origin_input_ids = []
|
||||||
|
|
||||||
skip_special_tokens = []
|
skip_special_tokens = []
|
||||||
spaces_between_special_tokens = []
|
spaces_between_special_tokens = []
|
||||||
@@ -1266,8 +1267,14 @@ class Scheduler:
|
|||||||
decode_ids, read_offset = req.init_incremental_detokenize()
|
decode_ids, read_offset = req.init_incremental_detokenize()
|
||||||
decode_ids_list.append(decode_ids)
|
decode_ids_list.append(decode_ids)
|
||||||
read_offsets.append(read_offset)
|
read_offsets.append(read_offset)
|
||||||
if self.skip_tokenizer_init:
|
if self.skip_tokenizer_init or self.server_args.return_token_ids:
|
||||||
output_ids.append(req.output_ids)
|
output_ids.append(req.output_ids)
|
||||||
|
else:
|
||||||
|
output_ids = None
|
||||||
|
if self.server_args.return_token_ids:
|
||||||
|
origin_input_ids.append(req.origin_input_ids)
|
||||||
|
else:
|
||||||
|
origin_input_ids = None
|
||||||
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
|
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
|
||||||
spaces_between_special_tokens.append(
|
spaces_between_special_tokens.append(
|
||||||
req.sampling_params.spaces_between_special_tokens
|
req.sampling_params.spaces_between_special_tokens
|
||||||
@@ -1299,6 +1306,7 @@ class Scheduler:
|
|||||||
decoded_texts,
|
decoded_texts,
|
||||||
decode_ids_list,
|
decode_ids_list,
|
||||||
read_offsets,
|
read_offsets,
|
||||||
|
origin_input_ids,
|
||||||
output_ids,
|
output_ids,
|
||||||
skip_special_tokens,
|
skip_special_tokens,
|
||||||
spaces_between_special_tokens,
|
spaces_between_special_tokens,
|
||||||
|
|||||||
@@ -663,6 +663,13 @@ class TokenizerManager:
|
|||||||
"text": recv_obj.output_strs[i],
|
"text": recv_obj.output_strs[i],
|
||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
|
if self.server_args.return_token_ids:
|
||||||
|
out_dict.update(
|
||||||
|
{
|
||||||
|
"input_ids": recv_obj.origin_input_ids[i],
|
||||||
|
"output_ids": recv_obj.output_ids[i],
|
||||||
|
}
|
||||||
|
)
|
||||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||||
out_dict = {
|
out_dict = {
|
||||||
"token_ids": recv_obj.output_ids[i],
|
"token_ids": recv_obj.output_ids[i],
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class ServerArgs:
|
|||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
is_embedding: bool = False
|
is_embedding: bool = False
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
return_token_ids: bool = False
|
||||||
|
|
||||||
# Port for the HTTP server
|
# Port for the HTTP server
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
@@ -280,6 +281,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="If set, skip init tokenizer and pass input_ids in generate request",
|
help="If set, skip init tokenizer and pass input_ids in generate request",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--return-token-ids",
|
||||||
|
action="store_true",
|
||||||
|
default=ServerArgs.return_token_ids,
|
||||||
|
help="Whether to return token IDs in the output, this may introduce additional overhead.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load-format",
|
"--load-format",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ suites = {
|
|||||||
"test_vision_chunked_prefill.py",
|
"test_vision_chunked_prefill.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
"test_session_control.py",
|
"test_session_control.py",
|
||||||
|
"test_engine_token_ids.py",
|
||||||
],
|
],
|
||||||
"nightly": [
|
"nightly": [
|
||||||
"test_nightly_gsm8k_eval.py",
|
"test_nightly_gsm8k_eval.py",
|
||||||
|
|||||||
59
test/srt/test_engine_token_ids.py
Normal file
59
test/srt/test_engine_token_ids.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
|
class TestEngineTokenIds(unittest.TestCase):
|
||||||
|
def test_token_ids_in_generate(self):
|
||||||
|
llm = sgl.Engine(
|
||||||
|
model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", return_token_ids=True
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
# The Hugging Face tokenizer prepends a start token to its output,
|
||||||
|
# while SGLang's output_ids contain only the generated token ids.
|
||||||
|
# We remove the start token from the HF output for comparison.
|
||||||
|
for prompt, output in zip(prompts, outputs):
|
||||||
|
hf_input_ids = tokenizer.encode(prompt)
|
||||||
|
self.assertEqual(
|
||||||
|
output["input_ids"],
|
||||||
|
hf_input_ids,
|
||||||
|
f"Input token IDs mismatch for: {prompt}",
|
||||||
|
)
|
||||||
|
|
||||||
|
hf_output_ids = tokenizer.encode(output["text"])[1:] # remove start token
|
||||||
|
self.assertEqual(
|
||||||
|
output["output_ids"],
|
||||||
|
hf_output_ids,
|
||||||
|
f"Output token IDs mismatch for: {output['text']}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
len(output["input_ids"]),
|
||||||
|
output["meta_info"]["prompt_tokens"],
|
||||||
|
"Prompt token count mismatch",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
len(output["output_ids"]),
|
||||||
|
output["meta_info"]["completion_tokens"],
|
||||||
|
"Completion token count mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
llm.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user