[Fix] fix eos trim inconsistency (#1650)
This commit is contained in:
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import zmq
|
||||
|
||||
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchTokenIDOut,
|
||||
UpdateWeightReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
|
||||
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import configure_logger, kill_parent_process
|
||||
from sglang.utils import find_printable_text, get_exception_traceback
|
||||
@@ -75,6 +75,21 @@ class DetokenizerManager:
|
||||
|
||||
self.decode_status = LimitedCapacityDict()
|
||||
|
||||
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
    """Strip the terminating stop string or stop token from a finished output.

    Args:
        output: Either decoded text (``str``) or a list of token ids.
        finished_reason: The finish reason for the request; only
            ``FINISH_MATCHED_STR`` and ``FINISH_MATCHED_TOKEN`` instances
            trigger trimming.
        no_eos_trim: When truthy, ``output`` is returned untouched.

    Returns:
        ``output`` with the matched stop string (and everything after its
        first occurrence) removed for text, or with the last token id
        removed for token lists; otherwise ``output`` unchanged.
    """
    if no_eos_trim:
        return output

    # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
    stopped_on_str = isinstance(finished_reason, FINISH_MATCHED_STR)
    if stopped_on_str and isinstance(output, str):
        cut = output.find(finished_reason.matched)
        if cut == -1:
            # The stop string is not present in this text; nothing to trim.
            return output
        return output[:cut]

    stopped_on_token = isinstance(finished_reason, FINISH_MATCHED_TOKEN)
    if stopped_on_token and isinstance(output, list):
        # Drop the final id of the token list.
        assert len(output) > 0
        return output[:-1]

    return output
|
||||
|
||||
def event_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
@@ -122,7 +137,13 @@ class DetokenizerManager:
|
||||
s = self.decode_status[rid]
|
||||
s.decode_ids = recv_obj.decode_ids[i]
|
||||
|
||||
read_ids.append(s.decode_ids[s.surr_offset :])
|
||||
read_ids.append(
|
||||
self.trim_eos(
|
||||
s.decode_ids[s.surr_offset :],
|
||||
recv_obj.finished_reason[i],
|
||||
recv_obj.no_eos_trim[i],
|
||||
)
|
||||
)
|
||||
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
|
||||
|
||||
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
||||
@@ -152,13 +173,13 @@ class DetokenizerManager:
|
||||
else:
|
||||
new_text = find_printable_text(new_text)
|
||||
|
||||
output_strs.append(s.decoded_text + new_text)
|
||||
|
||||
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
|
||||
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
|
||||
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
|
||||
if pos != -1:
|
||||
output_strs[i] = output_strs[i][:pos]
|
||||
output_strs.append(
|
||||
self.trim_eos(
|
||||
s.decoded_text + new_text,
|
||||
recv_obj.finished_reason[i],
|
||||
recv_obj.no_eos_trim[i],
|
||||
)
|
||||
)
|
||||
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
BatchStrOut(
|
||||
|
||||
@@ -295,6 +295,7 @@ class BatchTokenIDOut:
|
||||
spaces_between_special_tokens: List[bool]
|
||||
meta_info: List[Dict]
|
||||
finished_reason: List[BaseFinishReason]
|
||||
no_eos_trim: List[bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -883,6 +883,7 @@ class Scheduler:
|
||||
output_read_offsets = []
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
output_no_eos_trim = []
|
||||
else: # embedding or reward model
|
||||
output_embeddings = []
|
||||
unfinished_indices = []
|
||||
@@ -914,6 +915,7 @@ class Scheduler:
|
||||
output_spaces_between_special_tokens.append(
|
||||
req.sampling_params.spaces_between_special_tokens
|
||||
)
|
||||
output_no_eos_trim.append(req.sampling_params.no_eos_trim)
|
||||
|
||||
meta_info = {
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
@@ -961,6 +963,7 @@ class Scheduler:
|
||||
output_spaces_between_special_tokens,
|
||||
output_meta_info,
|
||||
output_finished_reason,
|
||||
output_no_eos_trim,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
|
||||
Reference in New Issue
Block a user