Simplify stream_output (#2398)

This commit is contained in:
Lianmin Zheng
2024-12-08 12:27:13 -08:00
committed by GitHub
parent f62055b528
commit a6ca736c8e
9 changed files with 426 additions and 290 deletions

View File

@@ -17,7 +17,7 @@ import dataclasses
import logging
import signal
from collections import OrderedDict
from typing import List, Union
from typing import Dict, List, Union
import psutil
import setproctitle
@@ -76,17 +76,25 @@ class DetokenizerManager:
self.decode_status = LimitedCapacityDict()
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
if no_stop_trim:
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
if no_stop_trim or not finished_reason:
return output
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
pos = output.find(finished_reason.matched)
matched = finished_reason.get("matched", None)
if not matched:
return output
# TODO(lmzheng): handle the case where multiple stop strs are hit
# Trim stop str.
if isinstance(matched, str) and isinstance(output, str):
pos = output.find(matched)
return output[:pos] if pos != -1 else output
if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
output, list
):
# Trim stop token.
if isinstance(matched, int) and isinstance(output, list):
assert len(output) > 0
return output[:-1]
return output
@@ -125,9 +133,9 @@ class DetokenizerManager:
s.decode_ids = recv_obj.decode_ids[i]
read_ids.append(
self.trim_eos(
self.trim_matched_stop(
s.decode_ids[s.surr_offset :],
recv_obj.finished_reason[i],
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
)
@@ -150,7 +158,7 @@ class DetokenizerManager:
for i in range(bs):
s = self.decode_status[recv_obj.rids[i]]
new_text = read_texts[i][len(surr_texts[i]) :]
if recv_obj.finished_reason[i] is None:
if recv_obj.finished_reasons[i] is None:
# Streaming chunk: update the decode status
if len(new_text) > 0 and not new_text.endswith("�"):
s.decoded_text = s.decoded_text + new_text
@@ -161,9 +169,9 @@ class DetokenizerManager:
new_text = find_printable_text(new_text)
output_strs.append(
self.trim_eos(
self.trim_matched_stop(
s.decoded_text + new_text,
recv_obj.finished_reason[i],
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
)
@@ -171,9 +179,20 @@ class DetokenizerManager:
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
input_top_logprobs_val=recv_obj.input_top_logprobs_val,
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
)
)