Decode Incrementally (#517)

This commit is contained in:
Liangsheng Yin
2024-06-12 14:39:12 +08:00
committed by GitHub
parent 111991fe23
commit 9c902b1954
8 changed files with 345 additions and 135 deletions

View File

@@ -39,30 +39,24 @@ class DetokenizerManager:
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
assert isinstance(recv_obj, BatchTokenIDOut)
output_tokens = recv_obj.output_tokens
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
output_strs = self.tokenizer.batch_decode(
output_tokens,
surr_texts = self.tokenizer.batch_decode(
recv_obj.surr_output_ids,
skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
0
],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
)
read_texts = self.tokenizer.batch_decode(
recv_obj.read_output_ids,
skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
)
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
for i in range(len(output_strs)):
if len(output_tokens[i]) > 0:
first_token = self.tokenizer.convert_ids_to_tokens(
int(output_tokens[i][0])
)
if not isinstance(first_token, str):
first_token = first_token.decode("utf-8", errors="ignore")
if first_token.startswith(""):
output_strs[i] = " " + output_strs[i]
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
output_strs = []
for i in range(len(recv_obj.rids)):
new_text = read_texts[i][len(surr_texts[i]) :]
output_strs.append(recv_obj.decoded_texts[i] + new_text)
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)