[Feature] Get Token IDs with Engine.generate() (#2636)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Shi Shuai
2024-12-29 20:28:27 +00:00
committed by GitHub
parent b085e06b01
commit 35bdb48557
7 changed files with 92 additions and 2 deletions

View File

@@ -181,6 +181,8 @@ class DetokenizerManager:
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
prompt_tokens=recv_obj.prompt_tokens,
origin_input_ids=recv_obj.origin_input_ids,
output_ids=recv_obj.output_ids,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,

View File

@@ -323,7 +323,9 @@ class BatchTokenIDOut:
decoded_texts: List[str]
decode_ids: List[int]
read_offsets: List[int]
# Only used when `--skip-tokenizer-init`
# Only used when `--return-token-ids` is set
origin_input_ids: Optional[List[int]]
# Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
output_ids: Optional[List[int]]
# Detokenization configs
skip_special_tokens: List[bool]
@@ -354,6 +356,10 @@ class BatchStrOut:
# The output decoded strings
output_strs: List[str]
# The token ids
origin_input_ids: Optional[List[int]]
output_ids: Optional[List[int]]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]

View File

@@ -1218,6 +1218,7 @@ class Scheduler:
decode_ids_list = []
read_offsets = []
output_ids = []
origin_input_ids = []
skip_special_tokens = []
spaces_between_special_tokens = []
@@ -1266,8 +1267,14 @@ class Scheduler:
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
read_offsets.append(read_offset)
if self.skip_tokenizer_init:
if self.skip_tokenizer_init or self.server_args.return_token_ids:
output_ids.append(req.output_ids)
else:
output_ids = None
if self.server_args.return_token_ids:
origin_input_ids.append(req.origin_input_ids)
else:
origin_input_ids = None
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
@@ -1299,6 +1306,7 @@ class Scheduler:
decoded_texts,
decode_ids_list,
read_offsets,
origin_input_ids,
output_ids,
skip_special_tokens,
spaces_between_special_tokens,

View File

@@ -663,6 +663,13 @@ class TokenizerManager:
"text": recv_obj.output_strs[i],
"meta_info": meta_info,
}
if self.server_args.return_token_ids:
out_dict.update(
{
"input_ids": recv_obj.origin_input_ids[i],
"output_ids": recv_obj.output_ids[i],
}
)
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],

View File

@@ -54,6 +54,7 @@ class ServerArgs:
chat_template: Optional[str] = None
is_embedding: bool = False
revision: Optional[str] = None
return_token_ids: bool = False
# Port for the HTTP server
host: str = "127.0.0.1"
@@ -280,6 +281,12 @@ class ServerArgs:
action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request",
)
parser.add_argument(
"--return-token-ids",
action="store_true",
default=ServerArgs.return_token_ids,
help="Whether to return token IDs in the output, this may introduce additional overhead.",
)
parser.add_argument(
"--load-format",
type=str,