[Fix] fix eos trim inconsistency (#1650)

This commit is contained in:
Ying Sheng
2024-10-13 01:07:09 -07:00
committed by GitHub
parent c3f2fc5a7a
commit 4876117171
7 changed files with 77 additions and 27 deletions

View File

@@ -18,7 +18,7 @@ limitations under the License.
import dataclasses import dataclasses
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List, Union
import zmq import zmq
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut, BatchTokenIDOut,
UpdateWeightReqOutput, UpdateWeightReqOutput,
) )
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import find_printable_text, get_exception_traceback from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
self.decode_status = LimitedCapacityDict() self.decode_status = LimitedCapacityDict()
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
if no_eos_trim:
return output
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
pos = output.find(finished_reason.matched)
return output[:pos] if pos != -1 else output
if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
output, list
):
assert len(output) > 0
return output[:-1]
return output
def event_loop(self): def event_loop(self):
"""The event loop that handles requests""" """The event loop that handles requests"""
@@ -122,7 +137,13 @@ class DetokenizerManager:
s = self.decode_status[rid] s = self.decode_status[rid]
s.decode_ids = recv_obj.decode_ids[i] s.decode_ids = recv_obj.decode_ids[i]
read_ids.append(s.decode_ids[s.surr_offset :]) read_ids.append(
self.trim_eos(
s.decode_ids[s.surr_offset :],
recv_obj.finished_reason[i],
recv_obj.no_eos_trim[i],
)
)
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
else: else:
new_text = find_printable_text(new_text) new_text = find_printable_text(new_text)
output_strs.append(s.decoded_text + new_text) output_strs.append(
self.trim_eos(
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit s.decoded_text + new_text,
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): recv_obj.finished_reason[i],
pos = output_strs[i].find(recv_obj.finished_reason[i].matched) recv_obj.no_eos_trim[i],
if pos != -1: )
output_strs[i] = output_strs[i][:pos] )
self.send_to_tokenizer.send_pyobj( self.send_to_tokenizer.send_pyobj(
BatchStrOut( BatchStrOut(

View File

@@ -295,6 +295,7 @@ class BatchTokenIDOut:
spaces_between_special_tokens: List[bool] spaces_between_special_tokens: List[bool]
meta_info: List[Dict] meta_info: List[Dict]
finished_reason: List[BaseFinishReason] finished_reason: List[BaseFinishReason]
no_eos_trim: List[bool]
@dataclass @dataclass

View File

@@ -883,6 +883,7 @@ class Scheduler:
output_read_offsets = [] output_read_offsets = []
output_skip_special_tokens = [] output_skip_special_tokens = []
output_spaces_between_special_tokens = [] output_spaces_between_special_tokens = []
output_no_eos_trim = []
else: # embedding or reward model else: # embedding or reward model
output_embeddings = [] output_embeddings = []
unfinished_indices = [] unfinished_indices = []
@@ -914,6 +915,7 @@ class Scheduler:
output_spaces_between_special_tokens.append( output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens req.sampling_params.spaces_between_special_tokens
) )
output_no_eos_trim.append(req.sampling_params.no_eos_trim)
meta_info = { meta_info = {
"prompt_tokens": len(req.origin_input_ids), "prompt_tokens": len(req.origin_input_ids),
@@ -961,6 +963,7 @@ class Scheduler:
output_spaces_between_special_tokens, output_spaces_between_special_tokens,
output_meta_info, output_meta_info,
output_finished_reason, output_finished_reason,
output_no_eos_trim,
) )
) )
else: # embedding or reward model else: # embedding or reward model

View File

@@ -493,7 +493,13 @@ def v1_generate_request(
top_logprobs_nums.append( top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0 request.logprobs if request.logprobs is not None else 0
) )
sampling_params_list.append( sampling_params = []
if isinstance(request.no_eos_trim, list):
num_reqs = len(request.prompt)
else:
num_reqs = 1
for i in range(num_reqs):
sampling_params.append(
{ {
"temperature": request.temperature, "temperature": request.temperature,
"max_new_tokens": request.max_tokens, "max_new_tokens": request.max_tokens,
@@ -508,8 +514,17 @@ def v1_generate_request(
"json_schema": request.json_schema, "json_schema": request.json_schema,
"n": request.n, "n": request.n,
"ignore_eos": request.ignore_eos, "ignore_eos": request.ignore_eos,
"no_eos_trim": (
request.no_eos_trim
if not isinstance(request.no_eos_trim, list)
else request.no_eos_trim[i]
),
} }
) )
if num_reqs == 1:
sampling_params_list.append(sampling_params[0])
else:
sampling_params_list.append(sampling_params)
if len(all_requests) == 1: if len(all_requests) == 1:
prompt = prompts[0] prompt = prompts[0]

View File

@@ -174,6 +174,7 @@ class CompletionRequest(BaseModel):
min_tokens: int = 0 min_tokens: int = 0
repetition_penalty: Optional[float] = 1.0 repetition_penalty: Optional[float] = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
no_eos_trim: Union[bool, List[bool]] = False
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(BaseModel):

View File

@@ -40,6 +40,7 @@ class SamplingParams:
regex: Optional[str] = None, regex: Optional[str] = None,
n: int = 1, n: int = 1,
json_schema: Optional[str] = None, json_schema: Optional[str] = None,
no_eos_trim: bool = False,
) -> None: ) -> None:
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
@@ -60,6 +61,7 @@ class SamplingParams:
self.regex = regex self.regex = regex
self.n = n self.n = n
self.json_schema = json_schema self.json_schema = json_schema
self.no_eos_trim = no_eos_trim
# Process some special cases # Process some special cases
if self.temperature < _SAMPLING_EPS: if self.temperature < _SAMPLING_EPS:

View File

@@ -690,3 +690,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
prof.export_chrome_trace(f"trace/{name}_{step_counter}.json") prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
step_counter += 1 step_counter += 1
return result return result
def first_rank_print(*args, **kwargs):
if torch.cuda.current_device() == 0:
print(*args, **kwargs)
else:
pass