diff --git a/examples/usage/openai_parallel_sample.py b/examples/usage/openai_parallel_sample.py
index 99ec0aa58..753e66c74 100644
--- a/examples/usage/openai_parallel_sample.py
+++ b/examples/usage/openai_parallel_sample.py
@@ -106,13 +106,12 @@ response = client.chat.completions.create(
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
     logprobs=True,
-    n=1,
+    top_logprobs=3,
 )
 print(response)
 
-# Chat completion
 response = client.chat.completions.create(
     model="default",
     messages=[
@@ -121,8 +120,34 @@ response = client.chat.completions.create(
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
+    n=1,
+)
+print(response)
+
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
     logprobs=True,
+    top_logprobs=3,
+)
+print(response)
+
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
     n=4,
 )
 print(response)
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 21f38cd22..c52d298d8 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
     FileRequest,
     FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
 
@@ -70,7 +73,7 @@ class FileMetadata:
 
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-## map file id to file path in SGlang backend
+# map file id to file path in the SGLang backend
 file_id_storage: Dict[str, str] = {}
 
@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             failed_requests += len(file_request_list)
 
         for idx, response in enumerate(responses):
-            ## the batch_req here can be changed to be named within a batch granularity
+            # the batch_req id here could instead be named at batch granularity
             response_json = {
                 "id": f"batch_req_{uuid.uuid4()}",
                 "custom_id": file_request_list[idx].get("custom_id"),
@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):
 
     prompts = []
     sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
         prompt = request.prompt
@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
         prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -361,6 +370,8 @@ def v1_generate_request(all_requests):
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
         else:
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
-
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
-        return_logprob=all_requests[0].logprobs is not None
-        and all_requests[0].logprobs > 0,
-        top_logprobs_num=(
-            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
-        ),
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
 
         logprobs = None
         if to_file:
-            ## to make the choise data json serializable
+            # to make the choice data JSON serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-                    ## remain the same but if needed we can change that
+                    # kept the same for now; can be changed if needed
                     "id": ret[i]["meta_info"]["id"],
                     "object": "text_completion",
                     "created": int(time.time()),
@@ -590,6 +597,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
     texts = []
     sampling_params_list = []
     image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -620,6 +629,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             stop = request.stop
             image_data = None
         texts.append(prompt)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -637,11 +648,16 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         texts = texts[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
         text=texts,
         image_data=image_data,
         sampling_params=sampling_params_list,
-        stream=request.stream,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
@@ -654,26 +670,65 @@ def v1_chat_generate_response(request, ret, to_file=False):
     total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            logprobs = to_openai_style_logprobs(
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+            token_logprobs = []
+            for token_idx, (token, logprob) in enumerate(
+                zip(logprobs.tokens, logprobs.token_logprobs)
+            ):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[token_idx].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+        else:
+            choice_logprobs = None
         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
         completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
-            ## to make the choice data json serializable
+            # to make the choice data JSON serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
-                "logprobs": None,
+                "logprobs": choice_logprobs,
                 "finish_reason": ret_item["meta_info"]["finish_reason"],
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
                 finish_reason=ret_item["meta_info"]["finish_reason"],
             )
         choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
+        total_prompt_tokens += prompt_tokens
         total_completion_tokens += completion_tokens
 
     if to_file:
         responses = []
@@ -683,7 +738,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-                    ## remain the same but if needed we can change that
+                    # kept the same for now; can be changed if needed
                     "id": ret[i]["meta_info"]["id"],
                     "object": "chat.completion",
                     "created": int(time.time()),
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 853165e34..c1d2a8cf3 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
+class TopLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+
+
+class ChatCompletionTokenLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+    top_logprobs: List[TopLogprob]
+
+
+class ChoiceLogprobs(BaseModel):
+    # built for the v1/chat/completions response
+    content: List[ChatCompletionTokenLogprob]
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+    finish_reason: str
 
 
 class ChatCompletionResponse(BaseModel):
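
For quick manual verification, a minimal sketch that exercises the new chat-completion logprobs path. It assumes an SGLang server is already serving the OpenAI-compatible API; the base_url, model name, and max_tokens value below are illustrative, and the api_key is a placeholder for a local server.

```python
import openai

# Assumed local SGLang endpoint; adjust base_url to your deployment.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=8,
    logprobs=True,
    top_logprobs=3,
)

# choices[0].logprobs follows the new ChoiceLogprobs shape: a `content` list
# of per-token entries, each with token, bytes, logprob, and top_logprobs.
for entry in response.choices[0].logprobs.content:
    alternatives = {alt.token: round(alt.logprob, 3) for alt in entry.top_logprobs}
    print(f"{entry.token!r}: logprob={entry.logprob:.3f}, top-3={alternatives}")
```

Since `return_logprob` and `top_logprobs_num` are now collected per request rather than taken from the first request only, the same per-token structure should come back for each choice in batched calls as well.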