[Feature] Get Token IDs with Engine.generate() (#2636)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Shi Shuai
2024-12-29 20:28:27 +00:00
committed by GitHub
parent b085e06b01
commit 35bdb48557
7 changed files with 92 additions and 2 deletions

View File

@@ -1218,6 +1218,7 @@ class Scheduler:
decode_ids_list = []
read_offsets = []
output_ids = []
origin_input_ids = []
skip_special_tokens = []
spaces_between_special_tokens = []
@@ -1266,8 +1267,14 @@ class Scheduler:
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
read_offsets.append(read_offset)
if self.skip_tokenizer_init:
if self.skip_tokenizer_init or self.server_args.return_token_ids:
output_ids.append(req.output_ids)
else:
output_ids = None
if self.server_args.return_token_ids:
origin_input_ids.append(req.origin_input_ids)
else:
origin_input_ids = None
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
@@ -1299,6 +1306,7 @@ class Scheduler:
decoded_texts,
decode_ids_list,
read_offsets,
origin_input_ids,
output_ids,
skip_special_tokens,
spaces_between_special_tokens,