[Fix] fix eos trim inconsistency (#1650)
This commit is contained in:
@@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
UpdateWeightReqOutput,
|
UpdateWeightReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
|
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import configure_logger, kill_parent_process
|
from sglang.srt.utils import configure_logger, kill_parent_process
|
||||||
from sglang.utils import find_printable_text, get_exception_traceback
|
from sglang.utils import find_printable_text, get_exception_traceback
|
||||||
@@ -75,6 +75,21 @@ class DetokenizerManager:
|
|||||||
|
|
||||||
self.decode_status = LimitedCapacityDict()
|
self.decode_status = LimitedCapacityDict()
|
||||||
|
|
||||||
|
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
    """Strip the matched stop string or trailing stop token from an output.

    Args:
        output: Either decoded text (``str``) or a list of token ids.
        finished_reason: Finish reason for the request. Only
            ``FINISH_MATCHED_STR`` (paired with a ``str`` output) and
            ``FINISH_MATCHED_TOKEN`` (paired with a ``list`` output) cause
            any trimming; every other combination is passed through.
        no_eos_trim: When truthy, trimming is disabled and ``output`` is
            returned untouched.

    Returns:
        ``output`` itself when nothing is trimmed, otherwise a truncated
        copy: text up to the first occurrence of the matched stop string,
        or the token list without its last element.
    """
    if no_eos_trim:
        return output

    # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
    if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
        matched = finished_reason.matched
        # An empty (or missing) stop string would be "found" at index 0 by
        # str.find and wipe the whole output, so there is nothing to trim.
        if not matched:
            return output
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output

    if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(output, list):
        # Drops the last id; presumably that is the matched stop token
        # (NOTE(review): confirm against the scheduler's finish check).
        assert len(output) > 0
        return output[:-1]

    return output
|
||||||
|
|
||||||
def event_loop(self):
|
def event_loop(self):
|
||||||
"""The event loop that handles requests"""
|
"""The event loop that handles requests"""
|
||||||
|
|
||||||
@@ -122,7 +137,13 @@ class DetokenizerManager:
|
|||||||
s = self.decode_status[rid]
|
s = self.decode_status[rid]
|
||||||
s.decode_ids = recv_obj.decode_ids[i]
|
s.decode_ids = recv_obj.decode_ids[i]
|
||||||
|
|
||||||
read_ids.append(s.decode_ids[s.surr_offset :])
|
read_ids.append(
|
||||||
|
self.trim_eos(
|
||||||
|
s.decode_ids[s.surr_offset :],
|
||||||
|
recv_obj.finished_reason[i],
|
||||||
|
recv_obj.no_eos_trim[i],
|
||||||
|
)
|
||||||
|
)
|
||||||
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
|
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
|
||||||
|
|
||||||
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
||||||
@@ -152,13 +173,13 @@ class DetokenizerManager:
|
|||||||
else:
|
else:
|
||||||
new_text = find_printable_text(new_text)
|
new_text = find_printable_text(new_text)
|
||||||
|
|
||||||
output_strs.append(s.decoded_text + new_text)
|
output_strs.append(
|
||||||
|
self.trim_eos(
|
||||||
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
|
s.decoded_text + new_text,
|
||||||
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
|
recv_obj.finished_reason[i],
|
||||||
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
|
recv_obj.no_eos_trim[i],
|
||||||
if pos != -1:
|
)
|
||||||
output_strs[i] = output_strs[i][:pos]
|
)
|
||||||
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
BatchStrOut(
|
BatchStrOut(
|
||||||
|
|||||||
@@ -295,6 +295,7 @@ class BatchTokenIDOut:
|
|||||||
spaces_between_special_tokens: List[bool]
|
spaces_between_special_tokens: List[bool]
|
||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
finished_reason: List[BaseFinishReason]
|
finished_reason: List[BaseFinishReason]
|
||||||
|
no_eos_trim: List[bool]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -883,6 +883,7 @@ class Scheduler:
|
|||||||
output_read_offsets = []
|
output_read_offsets = []
|
||||||
output_skip_special_tokens = []
|
output_skip_special_tokens = []
|
||||||
output_spaces_between_special_tokens = []
|
output_spaces_between_special_tokens = []
|
||||||
|
output_no_eos_trim = []
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
output_embeddings = []
|
output_embeddings = []
|
||||||
unfinished_indices = []
|
unfinished_indices = []
|
||||||
@@ -914,6 +915,7 @@ class Scheduler:
|
|||||||
output_spaces_between_special_tokens.append(
|
output_spaces_between_special_tokens.append(
|
||||||
req.sampling_params.spaces_between_special_tokens
|
req.sampling_params.spaces_between_special_tokens
|
||||||
)
|
)
|
||||||
|
output_no_eos_trim.append(req.sampling_params.no_eos_trim)
|
||||||
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"prompt_tokens": len(req.origin_input_ids),
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
@@ -961,6 +963,7 @@ class Scheduler:
|
|||||||
output_spaces_between_special_tokens,
|
output_spaces_between_special_tokens,
|
||||||
output_meta_info,
|
output_meta_info,
|
||||||
output_finished_reason,
|
output_finished_reason,
|
||||||
|
output_no_eos_trim,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
|
|||||||
@@ -493,7 +493,13 @@ def v1_generate_request(
|
|||||||
top_logprobs_nums.append(
|
top_logprobs_nums.append(
|
||||||
request.logprobs if request.logprobs is not None else 0
|
request.logprobs if request.logprobs is not None else 0
|
||||||
)
|
)
|
||||||
sampling_params_list.append(
|
sampling_params = []
|
||||||
|
if isinstance(request.no_eos_trim, list):
|
||||||
|
num_reqs = len(request.prompt)
|
||||||
|
else:
|
||||||
|
num_reqs = 1
|
||||||
|
for i in range(num_reqs):
|
||||||
|
sampling_params.append(
|
||||||
{
|
{
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
"max_new_tokens": request.max_tokens,
|
"max_new_tokens": request.max_tokens,
|
||||||
@@ -508,8 +514,17 @@ def v1_generate_request(
|
|||||||
"json_schema": request.json_schema,
|
"json_schema": request.json_schema,
|
||||||
"n": request.n,
|
"n": request.n,
|
||||||
"ignore_eos": request.ignore_eos,
|
"ignore_eos": request.ignore_eos,
|
||||||
|
"no_eos_trim": (
|
||||||
|
request.no_eos_trim
|
||||||
|
if not isinstance(request.no_eos_trim, list)
|
||||||
|
else request.no_eos_trim[i]
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if num_reqs == 1:
|
||||||
|
sampling_params_list.append(sampling_params[0])
|
||||||
|
else:
|
||||||
|
sampling_params_list.append(sampling_params)
|
||||||
|
|
||||||
if len(all_requests) == 1:
|
if len(all_requests) == 1:
|
||||||
prompt = prompts[0]
|
prompt = prompts[0]
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ class CompletionRequest(BaseModel):
|
|||||||
min_tokens: int = 0
|
min_tokens: int = 0
|
||||||
repetition_penalty: Optional[float] = 1.0
|
repetition_penalty: Optional[float] = 1.0
|
||||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||||
|
no_eos_trim: Union[bool, List[bool]] = False
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponseChoice(BaseModel):
|
class CompletionResponseChoice(BaseModel):
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ class SamplingParams:
|
|||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
n: int = 1,
|
n: int = 1,
|
||||||
json_schema: Optional[str] = None,
|
json_schema: Optional[str] = None,
|
||||||
|
no_eos_trim: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
@@ -60,6 +61,7 @@ class SamplingParams:
|
|||||||
self.regex = regex
|
self.regex = regex
|
||||||
self.n = n
|
self.n = n
|
||||||
self.json_schema = json_schema
|
self.json_schema = json_schema
|
||||||
|
self.no_eos_trim = no_eos_trim
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
if self.temperature < _SAMPLING_EPS:
|
if self.temperature < _SAMPLING_EPS:
|
||||||
|
|||||||
@@ -690,3 +690,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
|
|||||||
prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
|
prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
|
||||||
step_counter += 1
|
step_counter += 1
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def first_rank_print(*args, **kwargs):
    """Forward the arguments to ``print`` only when the current CUDA device index is 0.

    On any other device index the call does nothing. All positional and
    keyword arguments are passed to ``print`` unchanged.
    """
    on_first_device = torch.cuda.current_device() == 0
    if not on_first_device:
        return
    print(*args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user