Improve: Token-In Token-Out Usage for RLHF (#2843)
This commit is contained in:
@@ -181,8 +181,6 @@ class DetokenizerManager:
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
output_strs=output_strs,
|
||||
prompt_tokens=recv_obj.prompt_tokens,
|
||||
origin_input_ids=recv_obj.origin_input_ids,
|
||||
output_ids=recv_obj.output_ids,
|
||||
completion_tokens=recv_obj.completion_tokens,
|
||||
cached_tokens=recv_obj.cached_tokens,
|
||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||
|
||||
@@ -323,9 +323,7 @@ class BatchTokenIDOut:
|
||||
decoded_texts: List[str]
|
||||
decode_ids: List[int]
|
||||
read_offsets: List[int]
|
||||
# Only used when --return-token-ids` is set
|
||||
origin_input_ids: Optional[List[int]]
|
||||
# Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
|
||||
# Only used when `--skip-tokenizer-init` is on
|
||||
output_ids: Optional[List[int]]
|
||||
# Detokenization configs
|
||||
skip_special_tokens: List[bool]
|
||||
@@ -356,10 +354,6 @@ class BatchStrOut:
|
||||
# The output decoded strings
|
||||
output_strs: List[str]
|
||||
|
||||
# The token ids
|
||||
origin_input_ids: Optional[List[int]]
|
||||
output_ids: Optional[List[int]]
|
||||
|
||||
# Token counts
|
||||
# real input and output tokens can be get from
|
||||
# origin_input_ids and output_ids by enabling --return_token_ids
|
||||
|
||||
@@ -1253,7 +1253,6 @@ class Scheduler:
|
||||
decode_ids_list = []
|
||||
read_offsets = []
|
||||
output_ids = []
|
||||
origin_input_ids = []
|
||||
|
||||
skip_special_tokens = []
|
||||
spaces_between_special_tokens = []
|
||||
@@ -1305,14 +1304,8 @@ class Scheduler:
|
||||
decode_ids, read_offset = req.init_incremental_detokenize()
|
||||
decode_ids_list.append(decode_ids)
|
||||
read_offsets.append(read_offset)
|
||||
if self.skip_tokenizer_init or self.server_args.return_token_ids:
|
||||
if self.skip_tokenizer_init:
|
||||
output_ids.append(req.output_ids)
|
||||
else:
|
||||
output_ids = None
|
||||
if self.server_args.return_token_ids:
|
||||
origin_input_ids.append(req.origin_input_ids)
|
||||
else:
|
||||
origin_input_ids = None
|
||||
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
|
||||
spaces_between_special_tokens.append(
|
||||
req.sampling_params.spaces_between_special_tokens
|
||||
@@ -1344,7 +1337,6 @@ class Scheduler:
|
||||
decoded_texts,
|
||||
decode_ids_list,
|
||||
read_offsets,
|
||||
origin_input_ids,
|
||||
output_ids,
|
||||
skip_special_tokens,
|
||||
spaces_between_special_tokens,
|
||||
|
||||
@@ -663,13 +663,6 @@ class TokenizerManager:
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
if self.server_args.return_token_ids:
|
||||
out_dict.update(
|
||||
{
|
||||
"input_ids": recv_obj.origin_input_ids[i],
|
||||
"output_ids": recv_obj.output_ids[i],
|
||||
}
|
||||
)
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
|
||||
@@ -55,7 +55,6 @@ class ServerArgs:
|
||||
is_embedding: bool = False
|
||||
revision: Optional[str] = None
|
||||
skip_tokenizer_init: bool = False
|
||||
return_token_ids: bool = False
|
||||
|
||||
# Port for the HTTP server
|
||||
host: str = "127.0.0.1"
|
||||
@@ -296,6 +295,11 @@ class ServerArgs:
|
||||
"tokenizer if available, and 'slow' will "
|
||||
"always use the slow tokenizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-tokenizer-init",
|
||||
action="store_true",
|
||||
help="If set, skip init tokenizer and pass input_ids in generate request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-format",
|
||||
type=str,
|
||||
@@ -404,18 +408,6 @@ class ServerArgs:
|
||||
"name, a tag name, or a commit id. If unspecified, will use "
|
||||
"the default version.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-tokenizer-init",
|
||||
action="store_true",
|
||||
help="If set, skip init tokenizer and pass input_ids in generate request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--return-token-ids",
|
||||
action="store_true",
|
||||
default=ServerArgs.return_token_ids,
|
||||
help="Whether to return token IDs in the output, this may introduce additional overhead.",
|
||||
)
|
||||
|
||||
# Memory and scheduling
|
||||
parser.add_argument(
|
||||
"--mem-fraction-static",
|
||||
|
||||
Reference in New Issue
Block a user