Decode Incrementally (#517)
@@ -39,30 +39,24 @@ class DetokenizerManager:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
 
-            output_tokens = recv_obj.output_tokens
-
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-            output_strs = self.tokenizer.batch_decode(
-                output_tokens,
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                    0
-                ],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
 
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
-            for i in range(len(output_strs)):
-                if len(output_tokens[i]) > 0:
-                    first_token = self.tokenizer.convert_ids_to_tokens(
-                        int(output_tokens[i][0])
-                    )
-                    if not isinstance(first_token, str):
-                        first_token = first_token.decode("utf-8", errors="ignore")
-                    if first_token.startswith("▁"):
-                        output_strs[i] = " " + output_strs[i]
-
-                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
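
The change replaces full re-decoding of each request's output with incremental decoding: at every step the detokenizer decodes a short window of already-emitted "surrounding" tokens (surr_output_ids) and that same window plus the newly read tokens (read_output_ids), then appends only the suffix of the second string that extends past the first. Keeping the surrounding window lets byte-level BPE and SentencePiece merges (e.g. the leading "▁" space marker) come out correctly without re-decoding the whole sequence each step. Below is a minimal self-contained sketch of the same idea using a Hugging Face tokenizer; the function name, offset bookkeeping, and one-token surrounding window are illustrative simplifications, not the sglang API (the actual code tracks per-request offsets and defers advancing them on incomplete UTF-8 output).

# Sketch of incremental detokenization, assuming a Hugging Face tokenizer.
# Names (incremental_decode, surr_offset, read_offset) are hypothetical.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def incremental_decode(all_ids, decoded_text, surr_offset, read_offset):
    """Decode only the tokens past read_offset, using the
    surr_offset..read_offset window as detokenization context."""
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset], skip_special_tokens=True)
    read_text = tokenizer.decode(all_ids[surr_offset:], skip_special_tokens=True)
    new_text = read_text[len(surr_text):]
    # Advance the window: the old read position becomes the new surrounding start.
    return decoded_text + new_text, read_offset, len(all_ids)

# Usage: feed generated token ids one step at a time and accumulate text.
ids, text, surr, read = [], "", 0, 0
for tok in tokenizer.encode("Hello world, incremental decoding!"):
    ids.append(tok)
    text, surr, read = incremental_decode(ids, text, surr, read)
print(text)  # should match tokenizer.decode(ids) for this input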