Simplify stream_output (#2398)
This commit is contained in:
@@ -17,7 +17,7 @@ import dataclasses
|
||||
import logging
|
||||
import signal
|
||||
from collections import OrderedDict
|
||||
from typing import List, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -76,17 +76,25 @@ class DetokenizerManager:
|
||||
|
||||
self.decode_status = LimitedCapacityDict()
|
||||
|
||||
def trim_matched_stop(
    self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
) -> Union[str, List[int]]:
    """Trim the matched stop string / stop token from a decoded output.

    Args:
        output: Either the decoded text (``str``) or the raw token ids
            (``List[int]``) for one request.
        finished_reason: Dict-form finish reason; when generation stopped on a
            stop condition it carries a ``"matched"`` entry holding the stop
            string (``str``) or stop token id (``int``).
            NOTE(review): assumed to be the dict serialization of a finish
            reason — confirm against the producer side.
        no_stop_trim: When True the caller asked to keep stop strings/tokens
            in the output, so trimming is skipped entirely.

    Returns:
        The (possibly trimmed) output, same type as the input ``output``.
    """
    # Nothing to trim: trimming disabled, request not finished, or the
    # finish reason carries no matched stop condition.
    if no_stop_trim or not finished_reason:
        return output

    matched = finished_reason.get("matched", None)
    if not matched:
        return output

    # TODO(lmzheng): handle the case where multiple stop strs are hit

    # Trim stop str: cut at the first occurrence of the matched string.
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output

    # Trim stop token: the matched stop token id is the last token emitted,
    # so drop it from the id list.
    if isinstance(matched, int) and isinstance(output, list):
        assert len(output) > 0
        return output[:-1]
    return output
|
||||
@@ -125,9 +133,9 @@ class DetokenizerManager:
|
||||
s.decode_ids = recv_obj.decode_ids[i]
|
||||
|
||||
read_ids.append(
|
||||
self.trim_eos(
|
||||
self.trim_matched_stop(
|
||||
s.decode_ids[s.surr_offset :],
|
||||
recv_obj.finished_reason[i],
|
||||
recv_obj.finished_reasons[i],
|
||||
recv_obj.no_stop_trim[i],
|
||||
)
|
||||
)
|
||||
@@ -150,7 +158,7 @@ class DetokenizerManager:
|
||||
for i in range(bs):
|
||||
s = self.decode_status[recv_obj.rids[i]]
|
||||
new_text = read_texts[i][len(surr_texts[i]) :]
|
||||
if recv_obj.finished_reason[i] is None:
|
||||
if recv_obj.finished_reasons[i] is None:
|
||||
# Streaming chunk: update the decode status
|
||||
if len(new_text) > 0 and not new_text.endswith("<EFBFBD>"):
|
||||
s.decoded_text = s.decoded_text + new_text
|
||||
@@ -161,9 +169,9 @@ class DetokenizerManager:
|
||||
new_text = find_printable_text(new_text)
|
||||
|
||||
output_strs.append(
|
||||
self.trim_eos(
|
||||
self.trim_matched_stop(
|
||||
s.decoded_text + new_text,
|
||||
recv_obj.finished_reason[i],
|
||||
recv_obj.finished_reasons[i],
|
||||
recv_obj.no_stop_trim[i],
|
||||
)
|
||||
)
|
||||
@@ -171,9 +179,20 @@ class DetokenizerManager:
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
BatchStrOut(
|
||||
rids=recv_obj.rids,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
output_strs=output_strs,
|
||||
meta_info=recv_obj.meta_info,
|
||||
finished_reason=recv_obj.finished_reason,
|
||||
prompt_tokens=recv_obj.prompt_tokens,
|
||||
completion_tokens=recv_obj.completion_tokens,
|
||||
cached_tokens=recv_obj.cached_tokens,
|
||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
||||
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
||||
output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
|
||||
input_top_logprobs_val=recv_obj.input_top_logprobs_val,
|
||||
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
||||
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
||||
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
||||
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user