From 3a79613c28319030a5fe7fe22284b178f56984e1 Mon Sep 17 00:00:00 2001
From: yichuan~ <73766326+yichuan520030910320@users.noreply.github.com>
Date: Thu, 8 Aug 2024 17:41:57 +0800
Subject: [PATCH] support more options for usage in stream mode (#985)

Co-authored-by: Ying Sheng
---
 python/sglang/srt/managers/schedule_batch.py |  4 +-
 python/sglang/srt/openai_api/adapter.py      | 90 +++++++++++++++-----
 python/sglang/srt/openai_api/protocol.py     | 10 ++-
 test/srt/test_openai_server.py               | 20 ++++-
 4 files changed, 96 insertions(+), 28 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 5f026812a..714777dc1 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -217,7 +217,9 @@ class Req:
             return
 
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
+            self.finished_reason = FINISH_LENGTH(
+                length=self.sampling_params.max_new_tokens
+            )
             return
 
         if (
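
Note that the FINISH_LENGTH reason recorded above is later stringified and
normalized into an OpenAI-style finish_reason by the format_finish_reason
helper introduced in the adapter changes below. A minimal sketch of that
mapping; the logic is copied from the patch, but the concrete reason strings
in the asserts are illustrative assumptions, since only the prefixes checked
by the helper are fixed:

    from typing import Optional


    def format_finish_reason(finish_reason) -> Optional[str]:
        # Same prefix checks as the helper added to adapter.py below.
        if finish_reason.startswith("None"):
            return None
        elif finish_reason.startswith("FINISH_MATCHED"):
            return "stop"
        elif finish_reason.startswith("FINISH_LENGTH"):
            return "length"
        elif finish_reason.startswith("FINISH_ABORT"):
            return "abort"
        else:
            return "unknown"


    # Argument strings are illustrative assumptions; only the prefixes are known.
    assert format_finish_reason("None") is None
    assert format_finish_reason("FINISH_MATCHED_STR(matched='</s>')") == "stop"
    assert format_finish_reason("FINISH_LENGTH(length=16)") == "length"
    assert format_finish_reason("FINISH_ABORT()") == "abort"
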
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index c2cdfefe3..5e21d67e4 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -84,6 +84,19 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None
 
 
+def format_finish_reason(finish_reason) -> Optional[str]:
+    if finish_reason.startswith("None"):
+        return None
+    elif finish_reason.startswith("FINISH_MATCHED"):
+        return "stop"
+    elif finish_reason.startswith("FINISH_LENGTH"):
+        return "length"
+    elif finish_reason.startswith("FINISH_ABORT"):
+        return "abort"
+    else:
+        return "unknown"
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -486,14 +499,18 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )
 
         choices.append(choice_data)
@@ -608,20 +625,34 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         index=0,
                         text=delta,
                         logprobs=logprobs,
-                        finish_reason=content["meta_info"]["finish_reason"],
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                     )
                     chunk = CompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         object="text_completion",
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.stream_options and request.stream_options.include_usage:
+                    usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
+
+                    final_usage_chunk = CompletionStreamResponse(
+                        id=str(uuid.uuid4().hex),
+                        choices=[],
+                        model=request.model,
+                        usage=usage,
+                    )
+                    final_usage_data = final_usage_chunk.model_dump_json(
+                        exclude_unset=True, exclude_none=True
+                    )
+                    yield f"data: {final_usage_data}\n\n"
             except ValueError as e:
                 error = create_streaming_error_response(str(e))
                 yield f"data: {error}\n\n"
@@ -776,14 +807,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )
 
         choices.append(choice_data)
@@ -900,18 +935,15 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(role="assistant"),
-                        finish_reason=content["meta_info"]["finish_reason"],
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                         logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
 
@@ -921,20 +953,34 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(content=delta),
-                        finish_reason=content["meta_info"]["finish_reason"],
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                         logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.stream_options and request.stream_options.include_usage:
+                    usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
+
+                    final_usage_chunk = ChatCompletionStreamResponse(
+                        id=str(uuid.uuid4().hex),
+                        choices=[],
+                        model=request.model,
+                        usage=usage,
+                    )
+                    final_usage_data = final_usage_chunk.model_dump_json(
+                        exclude_unset=True, exclude_none=True
+                    )
+                    yield f"data: {final_usage_data}\n\n"
             except ValueError as e:
                 error = create_streaming_error_response(str(e))
                 yield f"data: {error}\n\n"
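
With the adapter changes in place, streamed chunks no longer carry per-chunk
usage; when a request sets stream_options.include_usage, one trailing
usage-only chunk (empty choices, populated usage) is emitted before the
stream closes. A sketch of consuming the resulting SSE wire format from the
completions endpoint; the server address and model name are placeholder
assumptions (sglang listens on port 30000 by default), and the closing
data: [DONE] sentinel is the usual OpenAI-style terminator:

    import json

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/completions",  # assumed local sglang server
        json={
            "model": "default",  # placeholder model name
            "prompt": "The capital of France is",
            "max_tokens": 16,
            "stream": True,
            "stream_options": {"include_usage": True},
        },
        stream=True,
    )

    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip keep-alive blanks
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if chunk["choices"]:
            # Normal chunks carry text; usage is now omitted from them.
            print(chunk["choices"][0]["text"], end="", flush=True)
        else:
            # The final chunk has empty choices and carries only usage.
            print("\nusage:", chunk["usage"])
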
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 3a91c12e8..2910dd5cd 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -78,6 +78,10 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class StreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -149,6 +153,7 @@ class CompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
@@ -188,7 +193,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
-    usage: UsageInfo
+    usage: Optional[UsageInfo] = None
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
@@ -247,6 +252,7 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     user: Optional[str] = None
@@ -294,6 +300,7 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = None
 
 
 class EmbeddingRequest(BaseModel):
@@ -310,3 +317,4 @@ class EmbeddingResponse(BaseModel):
     index: str
     embedding: List[float] = None
     object: str = "embedding"
+    usage: Optional[UsageInfo] = None
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index f86dc0650..f8f6ca632 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -98,10 +98,17 @@ class TestOpenAIServer(unittest.TestCase):
             echo=echo,
             logprobs=logprobs,
             stream=True,
+            stream_options={"include_usage": True},
         )
 
         first = True
         for response in generator:
+            usage = response.usage
+            if usage is not None:
+                assert usage.prompt_tokens > 0
+                assert usage.completion_tokens > 0
+                assert usage.total_tokens > 0
+                continue
             if logprobs:
                 assert response.choices[0].logprobs
                 assert isinstance(response.choices[0].logprobs.tokens[0], str)
@@ -122,12 +129,8 @@ class TestOpenAIServer(unittest.TestCase):
                     prompt
                 ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
                 first = False
-
             assert response.id
             assert response.created
-            assert response.usage.prompt_tokens > 0
-            assert response.usage.completion_tokens > 0
-            assert response.usage.total_tokens > 0
 
     def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -179,11 +182,20 @@ class TestOpenAIServer(unittest.TestCase):
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
             stream=True,
+            stream_options={"include_usage": True},
         )
 
         is_first = True
         for response in generator:
+            usage = response.usage
+            if usage is not None:
+                assert usage.prompt_tokens > 0
+                assert usage.completion_tokens > 0
+                assert usage.total_tokens > 0
+                continue
+
             data = response.choices[0].delta
+
             if is_first:
                 data.role == "assistant"
                 is_first = False
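
The same behavior can be exercised outside the test suite with the openai
client, as the tests above do. A small sketch for the chat endpoint; api_key,
base_url, and model are placeholders, and stream_options support requires a
reasonably recent openai Python package:

    import openai

    # Placeholder credentials and endpoint; point these at your own server.
    client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")

    stream = client.chat.completions.create(
        model="default",  # placeholder; use the model the server was launched with
        messages=[{"role": "user", "content": "Say this is a test."}],
        stream=True,
        stream_options={"include_usage": True},
    )

    for chunk in stream:
        if chunk.usage is not None:
            # Final usage-only chunk: choices is empty on this one.
            print("\nprompt:", chunk.usage.prompt_tokens,
                  "completion:", chunk.usage.completion_tokens,
                  "total:", chunk.usage.total_tokens)
            continue
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)
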