From d53dcf9c989fe4badcfbeb9d598adb7a3b6c9ab3 Mon Sep 17 00:00:00 2001
From: yichuan~ <73766326+yichuan520030910320@users.noreply.github.com>
Date: Mon, 5 Aug 2024 07:43:09 +0800
Subject: [PATCH] Support more OpenAI API test (#916)

---
 python/sglang/srt/managers/io_struct.py       |  13 +-
 .../sglang/srt/managers/tokenizer_manager.py  |  41 +++--
 python/sglang/srt/openai_api/adapter.py       | 140 +++++++++++++++---
 python/sglang/srt/openai_api/protocol.py      |   2 +-
 test/srt/test_openai_server.py                |  93 +++++++++---
 5 files changed, 230 insertions(+), 59 deletions(-)

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 036837a37..5aa767d58 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -92,7 +92,7 @@ class GenerateReqInput:
                 for element in parallel_sample_num_list
             )
             if parallel_sample_num > 1 and (not all_equal):
-                ## TODO cope with the case that the parallel_sample_num is different for different samples
+                # TODO cope with the case that the parallel_sample_num is different for different samples
                 raise ValueError(
                     "The parallel_sample_num should be the same for all samples in sample params."
                 )
@@ -103,14 +103,19 @@ class GenerateReqInput:
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
             num = parallel_sample_num + 1
-            if isinstance(self.text, List):
-                ## suppot batch operation
+            if isinstance(self.text, list):
+                # support batch operation
                 self.batch_size = len(self.text)
                 num = num * len(self.text)
+            elif isinstance(self.input_ids, list) and isinstance(
+                self.input_ids[0], list
+            ):
+                self.batch_size = len(self.input_ids)
+                num = num * len(self.input_ids)
             else:
                 self.batch_size = 1
         else:
-            ## support select operation
+            # support select operation
             num = len(self.text) if self.text is not None else len(self.input_ids)
             self.batch_size = num

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 6cceafda9..e44122bf1 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -153,8 +153,9 @@ class TokenizerManager:
     async def _handle_single_request(
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
-        if not is_cache_for_prefill:
-            not_use_index = not (index is not None)
+        if not is_cache_for_prefill:  # The normal case with a single prompt
+            not_use_index = index is None
+
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
@@ -182,14 +183,27 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
-        else:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
+        else:  # A prefill request to cache the common prompt for parallel sampling
+            if obj.text is not None:
+                if isinstance(obj.text, list):
+                    input_text = obj.text[index]
+                    rid = obj.rid[index]
+                else:
+                    input_text = obj.text
+                    rid = obj.rid[0]
+                input_ids = self.tokenizer.encode(input_text)
             else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
+                input_text = None
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj["input_ids"] is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
+
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
             pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -240,11 +254,11 @@ class TokenizerManager:
         ):
             if input_id_result is not None:
                 input_id_result.append(input_id)
-            pass
-        if len(input_id_result) > 1 and input_id_result is not None:
+        if input_id_result is not None and len(input_id_result) > 1:
             obj.input_ids = input_id_result
         elif input_id_result is not None:
             obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -264,11 +278,12 @@ class TokenizerManager:
                     input_text = None
                     input_ids = obj.input_ids[i]
                 else:
+                    assert obj.input_ids is not None
                     if batch_size == 1:
-                        input_text = obj.text
+                        input_text = None
                         input_ids = obj.input_ids
                     else:
-                        input_text = obj.text[i]
+                        input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 8ad028b1d..b51c12816 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -251,7 +251,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             if end_point == "/v1/chat/completions":
                 responses = v1_chat_generate_response(request, ret, to_file=True)
             else:
-                responses = v1_generate_response(request, ret, to_file=True)
+                responses = v1_generate_response(
+                    request, ret, tokenizer_manager, to_file=True
+                )

         except Exception as e:
             error_json = {
@@ -339,6 +341,7 @@ def v1_generate_request(all_requests):
     return_logprobs = []
     top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
+
     for request in all_requests:
         prompt = request.prompt
         assert (
@@ -364,7 +367,7 @@ def v1_generate_request(all_requests):
             )
         if len(all_requests) > 1 and request.n > 1:
             raise ValueError(
-                "Batch operation is not supported for completions from files"
+                "Parallel sampling is not supported for completions from files"
             )

    if len(all_requests) == 1:
@@ -377,10 +380,11 @@ def v1_generate_request(all_requests):
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
@@ -389,35 +393,52 @@ def v1_generate_request(all_requests):
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
+
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
     return adapted_request, all_requests


-def v1_generate_response(request, ret, to_file=False):
+def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
     choices = []
     echo = False

-    if (not isinstance(request, List)) and request.echo:
+    if (not isinstance(request, list)) and request.echo:
         # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
+        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+            # for the case of multiple str prompts
             prompts = request.prompt
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+            # for the case of multiple token ids prompts
+            prompts = [
+                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+                for prompt in request.prompt
+            ]
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+            # for the case of single token ids prompt
+            prompts = [
+                tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            ]
         else:
+            # for the case of single str prompt
             prompts = [request.prompt]
         echo = True

     for idx, ret_item in enumerate(ret):
         text = ret_item["text"]
-        if isinstance(request, List) and request[idx].echo:
+        if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request, List)) and echo:
-            text = prompts[idx] + text
+        if (not isinstance(request, list)) and echo:
+            prompt_index = idx // request.n
+            text = prompts[prompt_index] + text
         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             if echo:
@@ -479,15 +500,16 @@ def v1_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
-                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
     return response
@@ -513,8 +535,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     if not stream_buffer:
                         # The first chunk
                         if request.echo:
+                            if isinstance(request.prompt, str):
+                                # for the case of single str prompts
+                                prompts = request.prompt
+                            elif isinstance(request.prompt, list) and isinstance(
+                                request.prompt[0], int
+                            ):
+                                prompts = tokenizer_manager.tokenizer.decode(
+                                    request.prompt, skip_special_tokens=True
+                                )
+
                             # Prepend prompt in response text.
-                            text = request.prompt + text
+                            text = prompts + text

                         if request.logprobs:
                             # The first chunk and echo is enabled.
@@ -539,7 +571,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                                     "output_top_logprobs"
                                 ][n_prev_token:],
                             )
-
                             n_prev_token = len(
                                 content["meta_info"]["output_token_logprobs"]
                             )
@@ -588,7 +619,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_generate_response(request, ret)
+    response = v1_generate_response(request, ret, tokenizer_manager)

     return response
@@ -626,7 +657,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
-            prompt = request.messages
+            prompt_ids = request.messages
             stop = request.stop
             image_data = None
         input_ids.append(prompt_ids)
@@ -647,12 +678,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         image_data_list.append(image_data)
     if len(all_requests) == 1:
         input_ids = input_ids[0]
+        if isinstance(input_ids, str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
+    else:
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
     adapted_request = GenerateReqInput(
-        input_ids=input_ids,
+        **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
@@ -672,9 +712,9 @@ def v1_chat_generate_response(request, ret, to_file=False):

     for idx, ret_item in enumerate(ret):
         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             logprobs = to_openai_style_logprobs(
@@ -779,10 +819,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
            is_first = True

            stream_buffer = ""
+            n_prev_token = 0
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    if request.logprobs:
+                        logprobs = to_openai_style_logprobs(
+                            output_token_logprobs=content["meta_info"][
+                                "output_token_logprobs"
+                            ][n_prev_token:],
+                            output_top_logprobs=content["meta_info"][
+                                "output_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["output_token_logprobs"]
+                        )
+                        token_logprobs = []
+                        for token, logprob in zip(
+                            logprobs.tokens, logprobs.token_logprobs
+                        ):
+                            token_bytes = list(token.encode("utf-8"))
+                            top_logprobs = []
+                            if logprobs.top_logprobs:
+                                for top_token, top_logprob in logprobs.top_logprobs[
+                                    0
+                                ].items():
+                                    top_token_bytes = list(top_token.encode("utf-8"))
+                                    top_logprobs.append(
+                                        TopLogprob(
+                                            token=top_token,
+                                            bytes=top_token_bytes,
+                                            logprob=top_logprob,
+                                        )
+                                    )
+                            token_logprobs.append(
+                                ChatCompletionTokenLogprob(
+                                    token=token,
+                                    bytes=token_bytes,
+                                    logprob=logprob,
+                                    top_logprobs=top_logprobs,
+                                )
+                            )
+
+                        choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+
+                    else:
+                        choice_logprobs = None
+
                    if is_first:
                        # First chunk with role
                        is_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=0,
                            delta=DeltaMessage(role="assistant"),
                            finish_reason=content["meta_info"]["finish_reason"],
+                            logprobs=choice_logprobs,
                        )
                        chunk = ChatCompletionStreamResponse(
                            id=content["meta_info"]["id"],
                            choices=[choice_data],
                            model=request.model,
+                            usage=UsageInfo(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=prompt_tokens + completion_tokens,
+                            ),
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"
@@ -805,11 +899,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                        index=0,
                        delta=DeltaMessage(content=delta),
                        finish_reason=content["meta_info"]["finish_reason"],
+                        logprobs=choice_logprobs,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        choices=[choice_data],
                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
                    )
                    yield f"data: {chunk.model_dump_json()}\n\n"
        except ValueError as e:
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index c1d2a8cf3..8c079dd2a 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -278,7 +278,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None

diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index 5e37b1b4d..45648ce1d 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -3,6 +3,7 @@
 import unittest

 import openai

+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
@@ -18,60 +19,85 @@ class TestOpenAIServer(unittest.TestCase):
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

-    def run_completion(self, echo, logprobs, use_list_input):
+    def run_completion(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
+        if token_input:
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
+        else:
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
-            prompt_arg = [prompt, prompt]
+            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
        else:
-            prompt_arg = prompt
+            prompt_arg = prompt_input
            num_choices = 1

+        if parallel_sample_num:
+            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
+            # parallel sampling.
+            num_prompt_tokens *= parallel_sample_num
+
        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
-            temperature=0.1,
+            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
+            n=parallel_sample_num,
        )

-        assert len(response.choices) == num_choices
+        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)
+
        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
-            # FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
+            # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output ids map to the same output token, so they collide in the map.
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+            assert ret_num_top_logprobs > 0

            if echo:
                assert response.choices[0].logprobs.token_logprobs[0] == None
            else:
                assert response.choices[0].logprobs.token_logprobs[0] != None
+
        assert response.id
        assert response.created
-        assert response.usage.prompt_tokens > 0
+        assert (
+            response.usage.prompt_tokens == num_prompt_tokens
+        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs):
+    def run_completion_stream(self, echo, logprobs, token_input):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
+        if token_input:
+            prompt_arg = self.tokenizer.encode(prompt)
+        else:
+            prompt_arg = prompt
        generator = client.completions.create(
            model=self.model,
-            prompt=prompt,
-            temperature=0.1,
+            prompt=prompt_arg,
+            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
@@ -90,12 +116,15 @@ class TestOpenAIServer(unittest.TestCase):
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
-                    # FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
+                    # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output ids map to the same output token, so they collide in the map.
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+                    assert ret_num_top_logprobs > 0

            if first:
                if echo:
-                    assert response.choices[0].text.startswith(prompt)
+                    assert response.choices[0].text.startswith(
+                        prompt
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
                first = False

            assert response.id
@@ -104,7 +133,7 @@ class TestOpenAIServer(unittest.TestCase):
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

-    def run_chat_completion(self, logprobs):
+    def run_chat_completion(self, logprobs, parallel_sample_num):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
@@ -116,6 +145,7 @@ class TestOpenAIServer(unittest.TestCase):
            max_tokens=32,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
+            n=parallel_sample_num,
        )
        if logprobs:
            assert isinstance(
@@ -128,7 +158,7 @@ class TestOpenAIServer(unittest.TestCase):
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"
-
+        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
@@ -161,11 +191,21 @@ class TestOpenAIServer(unittest.TestCase):
                continue

            if logprobs:
-                # FIXME: Fix this bug. Return top_logprobs in the streaming mode.
-                pass
+                assert response.choices[0].logprobs
+                assert isinstance(
+                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
+                )
+                assert isinstance(
+                    response.choices[0].logprobs.content[0].top_logprobs, list
+                )
+                ret_num_top_logprobs = len(
+                    response.choices[0].logprobs.content[0].top_logprobs
+                )
+                assert (
+                    ret_num_top_logprobs == logprobs
+                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
-
            assert response.id
            assert response.created
@@ -173,16 +213,27 @@ class TestOpenAIServer(unittest.TestCase):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
-                    self.run_completion(echo, logprobs, use_list_input)
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

    def test_completion_stream(self):
+        # parallel sampling and list input are not supported in streaming mode
        for echo in [False, True]:
            for logprobs in [None, 5]:
-                self.run_completion_stream(echo, logprobs)
+                for token_input in [False, True]:
+                    self.run_completion_stream(echo, logprobs, token_input)

    def test_chat_completion(self):
        for logprobs in [None, 5]:
-            self.run_chat_completion(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
@@ -224,5 +275,5 @@ if __name__ == "__main__":

    # t = TestOpenAIServer()
    # t.setUpClass()
-    # t.test_chat_completion_stream()
+    # t.test_completion()
    # t.tearDownClass()
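
Note (not part of the patch): a minimal usage sketch of the options this PR adds test coverage for, sent through the OpenAI-compatible completions endpoint. It mirrors the test code above; the base URL, API key, model name, and the literal token ids below are placeholders and assumptions, not values taken from this PR.

import openai

# Assumed local endpoint and model name; adjust to match a running sglang server.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

# Text prompt with parallel sampling (n=2), echo, and logprobs in one request.
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    n=2,
    temperature=0,
    max_tokens=16,
    echo=True,
    logprobs=5,
)
for choice in response.choices:
    print(choice.text)
print(response.usage)

# Token-id prompt (hypothetical ids; the tests build them with get_tokenizer(...).encode(prompt)).
response = client.completions.create(
    model="default",
    prompt=[450, 7483, 310, 3444, 338],
    temperature=0,
    max_tokens=16,
)
print(response.choices[0].text)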