[Fix] fix eos trim inconsistency (#1650)

This commit is contained in:
Ying Sheng
2024-10-13 01:07:09 -07:00
committed by GitHub
parent c3f2fc5a7a
commit 4876117171
7 changed files with 77 additions and 27 deletions

View File

@@ -18,7 +18,7 @@ limitations under the License.
import dataclasses
import logging
from collections import OrderedDict
from typing import List
from typing import List, Union
import zmq
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
self.decode_status = LimitedCapacityDict()
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
    """Trim the matched stop sequence from the end of *output*.

    Args:
        output: Either the incrementally decoded text (``str``) or the raw
            token-id list (``List[int]``) for one request.
        finished_reason: The request's finish reason; trimming only applies
            to ``FINISH_MATCHED_STR`` (string output) and
            ``FINISH_MATCHED_TOKEN`` (token-id output).
        no_eos_trim: When truthy, the caller asked to keep the stop
            sequence, so the output is returned untouched.

    Returns:
        The possibly-trimmed output, same type as the input.
    """
    if no_eos_trim:
        return output
    # isinstance(None, T) is always False, so returning early for an
    # unfinished request is behavior-preserving and skips two checks.
    if finished_reason is None:
        return output

    # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
    if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
        pos = output.find(finished_reason.matched)
        return output[:pos] if pos != -1 else output
    if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(output, list):
        # Drop the matched stop token. Tolerate an empty list instead of
        # asserting: `assert` is removed under `python -O`, and an empty
        # decode-id slice should not crash the detokenizer loop.
        return output[:-1] if output else output
    return output
def event_loop(self):
"""The event loop that handles requests"""
@@ -122,7 +137,13 @@ class DetokenizerManager:
s = self.decode_status[rid]
s.decode_ids = recv_obj.decode_ids[i]
read_ids.append(s.decode_ids[s.surr_offset :])
read_ids.append(
self.trim_eos(
s.decode_ids[s.surr_offset :],
recv_obj.finished_reason[i],
recv_obj.no_eos_trim[i],
)
)
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
else:
new_text = find_printable_text(new_text)
output_strs.append(s.decoded_text + new_text)
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
if pos != -1:
output_strs[i] = output_strs[i][:pos]
output_strs.append(
self.trim_eos(
s.decoded_text + new_text,
recv_obj.finished_reason[i],
recv_obj.no_eos_trim[i],
)
)
self.send_to_tokenizer.send_pyobj(
BatchStrOut(

View File

@@ -295,6 +295,7 @@ class BatchTokenIDOut:
spaces_between_special_tokens: List[bool]
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
no_eos_trim: List[bool]
@dataclass

View File

@@ -883,6 +883,7 @@ class Scheduler:
output_read_offsets = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_eos_trim = []
else: # embedding or reward model
output_embeddings = []
unfinished_indices = []
@@ -914,6 +915,7 @@ class Scheduler:
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
output_no_eos_trim.append(req.sampling_params.no_eos_trim)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
@@ -961,6 +963,7 @@ class Scheduler:
output_spaces_between_special_tokens,
output_meta_info,
output_finished_reason,
output_no_eos_trim,
)
)
else: # embedding or reward model