From b912de11b0a58330064b3d72db6ea0fad515d468 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 12 Sep 2024 20:47:31 -0700 Subject: [PATCH] Make stop reason a dict instead of str (#1407) --- python/sglang/srt/managers/schedule_batch.py | 43 +++++++++------ python/sglang/srt/managers/tp_worker.py | 6 ++- python/sglang/srt/openai_api/adapter.py | 55 ++++++++++---------- 3 files changed, 60 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6a8e4d9f1..17d13c7a5 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -56,7 +56,7 @@ class BaseFinishReason: def __init__(self, is_error: bool = False): self.is_error = is_error - def __str__(self): + def to_json(self): raise NotImplementedError("Subclasses must implement this method") @@ -65,17 +65,11 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason): super().__init__() self.matched = matched - def __str__(self) -> str: - return f"FINISH_MATCHED_TOKEN: {self.matched}" - - -class FINISH_LENGTH(BaseFinishReason): - def __init__(self, length: int): - super().__init__() - self.length = length - - def __str__(self) -> str: - return f"FINISH_LENGTH: {self.length}" + def to_json(self): + return { + "type": "stop", # to match OpenAI API's return value + "matched": self.matched, + } class FINISH_MATCHED_STR(BaseFinishReason): @@ -83,16 +77,33 @@ class FINISH_MATCHED_STR(BaseFinishReason): super().__init__() self.matched = matched - def __str__(self) -> str: - return f"FINISH_MATCHED_STR: {self.matched}" + def to_json(self): + return { + "type": "stop", # to match OpenAI API's return value + "matched": self.matched, + } + + +class FINISH_LENGTH(BaseFinishReason): + def __init__(self, length: int): + super().__init__() + self.length = length + + def to_json(self): + return { + "type": "length", # to match OpenAI API's return value + "length": self.length, + } class FINISH_ABORT(BaseFinishReason): def __init__(self): super().__init__(is_error=True) - def __str__(self) -> str: - return "FINISH_ABORT" + def to_json(self): + return { + "type": "abort", + } class Req: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 096c13108..05619aae1 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -813,7 +813,11 @@ class ModelTpServer: "prompt_tokens": len(req.origin_input_ids), "completion_tokens": len(req.output_ids), "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "finish_reason": str(req.finished_reason), + "finish_reason": ( + req.finished_reason.to_json() + if req.finished_reason is not None + else None + ), } if req.return_logprob: ( diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index d1b296e9b..1b8169af6 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -95,19 +95,6 @@ file_id_storage: Dict[str, str] = {} storage_dir = None -def format_finish_reason(finish_reason) -> Optional[str]: - if finish_reason.startswith("None"): - return None - elif finish_reason.startswith("FINISH_MATCHED"): - return "stop" - elif finish_reason.startswith("FINISH_LENGTH"): - return "length" - elif finish_reason.startswith("FINISH_ABORT"): - return "abort" - else: - return "unknown" - - def create_error_response( message: str, err_type: str = "BadRequestError", @@ -618,8 +605,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): "index": 0, "text": text, "logprobs": logprobs, - "finish_reason": format_finish_reason( - ret_item["meta_info"]["finish_reason"] + "finish_reason": ( + ret_item["meta_info"]["finish_reason"]["type"] + if ret_item["meta_info"]["finish_reason"] + else "" ), } else: @@ -627,8 +616,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): index=idx, text=text, logprobs=logprobs, - finish_reason=format_finish_reason( - ret_item["meta_info"]["finish_reason"] + finish_reason=( + ret_item["meta_info"]["finish_reason"]["type"] + if ret_item["meta_info"]["finish_reason"] + else "" ), ) @@ -762,8 +753,10 @@ async def v1_completions(tokenizer_manager, raw_request: Request): index=index, text=delta, logprobs=logprobs, - finish_reason=format_finish_reason( - content["meta_info"]["finish_reason"] + finish_reason=( + content["meta_info"]["finish_reason"]["type"] + if content["meta_info"]["finish_reason"] + else "" ), ) chunk = CompletionStreamResponse( @@ -999,8 +992,10 @@ def v1_chat_generate_response(request, ret, to_file=False): "index": 0, "message": {"role": "assistant", "content": ret_item["text"]}, "logprobs": choice_logprobs, - "finish_reason": format_finish_reason( - ret_item["meta_info"]["finish_reason"] + "finish_reason": ( + ret_item["meta_info"]["finish_reason"]["type"] + if ret_item["meta_info"]["finish_reason"] + else "" ), } else: @@ -1008,8 +1003,10 @@ def v1_chat_generate_response(request, ret, to_file=False): index=idx, message=ChatMessage(role="assistant", content=ret_item["text"]), logprobs=choice_logprobs, - finish_reason=format_finish_reason( - ret_item["meta_info"]["finish_reason"] + finish_reason=( + ret_item["meta_info"]["finish_reason"]["type"] + if ret_item["meta_info"]["finish_reason"] + else "" ), ) @@ -1134,8 +1131,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): choice_data = ChatCompletionResponseStreamChoice( index=index, delta=DeltaMessage(role="assistant"), - finish_reason=format_finish_reason( - content["meta_info"]["finish_reason"] + finish_reason=( + content["meta_info"]["finish_reason"]["type"] + if content["meta_info"]["finish_reason"] + else "" ), logprobs=choice_logprobs, ) @@ -1152,8 +1151,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): choice_data = ChatCompletionResponseStreamChoice( index=index, delta=DeltaMessage(content=delta), - finish_reason=format_finish_reason( - content["meta_info"]["finish_reason"] + finish_reason=( + content["meta_info"]["finish_reason"]["type"] + if content["meta_info"]["finish_reason"] + else "" ), logprobs=choice_logprobs, )