From 49c5e0eca9fe0193e716d0a51bdc2ec7c90a0184 Mon Sep 17 00:00:00 2001
From: yichuan~ <73766326+yichuan520030910320@users.noreply.github.com>
Date: Sat, 20 Jul 2024 14:10:01 +0800
Subject: [PATCH] Add support for OpenAI API parallel sampling (#640)

---
 examples/usage/openai_parallel_sample.py      |  75 ++++
 python/sglang/srt/managers/io_struct.py       |  27 +-
 .../sglang/srt/managers/tokenizer_manager.py  | 329 ++++++++++++------
 python/sglang/srt/openai_api_adapter.py       | 114 +++---
 python/sglang/srt/sampling_params.py          |   2 +
 5 files changed, 380 insertions(+), 167 deletions(-)
 create mode 100644 examples/usage/openai_parallel_sample.py

diff --git a/examples/usage/openai_parallel_sample.py b/examples/usage/openai_parallel_sample.py
new file mode 100644
index 000000000..d2d1e406f
--- /dev/null
+++ b/examples/usage/openai_parallel_sample.py
@@ -0,0 +1,75 @@
+import openai
+
+client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
+
+# Text completion: single prompt, n=1
+response = client.completions.create(
+    model="default",
+    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
+    n=1,
+    temperature=0.8,
+    max_tokens=32,
+)
+print(response)
+
+
+# Text completion: single prompt, n=3 parallel samples
+response = client.completions.create(
+    model="default",
+    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
+    n=3,
+    temperature=0.8,
+    max_tokens=32,
+)
+print(response)
+
+
+# Text completion: batch of prompts, n=1
+response = client.completions.create(
+    model="default",
+    prompt=["The name of the famous soccer player is ", "The capital of US is"],
+    n=1,
+    temperature=0.8,
+    max_tokens=32,
+)
+print(response)
+
+
+# Text completion: batch of prompts, n=3 parallel samples each
+response = client.completions.create(
+    model="default",
+    prompt=["The name of the famous soccer player is ", "The capital of US is"],
+    n=3,
+    temperature=0.8,
+    max_tokens=32,
+)
+print(response)
+
+
+# Text completion: batch of three prompts, n=3 parallel samples each
+response = client.completions.create(
+    model="default",
+    prompt=[
+        "The capital of France is",
+        "The capital of Germany is",
+        "The capital of US is",
+    ],
+    n=3,
+    temperature=0.8,
+    max_tokens=32,
+)
+print(response)
+
+# Chat completion: n=4 parallel samples with logprobs
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=64,
+    logprobs=True,
+    n=4,
+)
+print(response)
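Note on reading the responses above: with n > 1 the adapter (see the openai_api_adapter.py changes below) emits one choice per sample, and for a batch of prompts the choices for all prompts come back flattened in a single list, ordered prompt by prompt. A minimal sketch of regrouping them, assuming the same local server as in the example:

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

prompts = ["The capital of France is", "The capital of Germany is"]
n = 3
response = client.completions.create(
    model="default", prompt=prompts, n=n, temperature=0.8, max_tokens=32
)
# Choices arrive flattened across prompts, n per prompt, so regroup by slice.
for i, prompt in enumerate(prompts):
    samples = response.choices[i * n : (i + 1) * n]
    print(prompt, "->", [choice.text for choice in samples])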
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 89de9b1c3..2ba2095c7 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -40,11 +40,13 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-
-        if self.text is not None:
-            is_single = isinstance(self.text, str)
+        if "n" in self.sampling_params and self.sampling_params["n"] != 1:
+            is_single = False
         else:
-            is_single = isinstance(self.input_ids[0], int)
+            if self.text is not None:
+                is_single = isinstance(self.text, str)
+            else:
+                is_single = isinstance(self.input_ids[0], int)

         self.is_single = is_single
         if is_single:
@@ -59,7 +61,22 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            num = len(self.text) if self.text is not None else len(self.input_ids)
+
+            parallel_sample_num = self.sampling_params.get("n", 1)
+
+            if parallel_sample_num != 1:
+                # Parallel sampling: the +1 represents the original prefill stage.
+                num = parallel_sample_num + 1
+                if isinstance(self.text, List):
+                    ## support batch operation
+                    self.batch_size = len(self.text)
+                    num = num * len(self.text)
+                else:
+                    self.batch_size = 1
+            else:
+                ## support select operation
+                num = len(self.text) if self.text is not None else len(self.input_ids)
+                self.batch_size = num

             if self.image_data is None:
                 self.image_data = [None] * num
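The sizing above is the heart of the change: with n != 1 every prompt contributes one prefill request plus n sampling requests, and every per-request list (rid, image_data, sampling_params, and so on) is padded to that length. A worked sketch of the arithmetic:

# Sketch: request-slot arithmetic from GenerateReqInput.post_init.
# With n != 1, each prompt needs one prefill slot plus n sampling slots.
batch_size = 2              # two prompts
n = 3                       # three parallel samples per prompt
num = (n + 1) * batch_size  # 8 slots: 2 prefill + 6 sampling requests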
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index d5a81bfaa..0eabb480a 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -122,125 +122,150 @@ class TokenizerManager:
         obj.post_init()
         is_single = obj.is_single
+
         if is_single:
-            rid = obj.rid
-
-            if obj.input_ids is None:
-                input_ids = self.tokenizer.encode(obj.text)
-            else:
-                input_ids = obj.input_ids
-
-            if len(input_ids) >= self.context_len:
-                raise ValueError(
-                    f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)."
-                )
-
-            sampling_params = SamplingParams(**obj.sampling_params)
-            if sampling_params.max_new_tokens != 0:
-                sampling_params.normalize(self.tokenizer)
-                sampling_params.verify()
-
-            if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
-                    obj.image_data[0]
-                )
-            elif isinstance(obj.image_data, str):
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
-                    obj.image_data
-                )
-            else:
-                pixel_values, image_hash, image_size = None, None, None
-            tokenized_obj = TokenizedGenerateReqInput(
-                rid=rid,
-                input_text=obj.text,
-                input_ids=input_ids,
-                pixel_values=pixel_values,
-                image_hash=image_hash,
-                image_size=image_size,
-                sampling_params=sampling_params,
-                return_logprob=obj.return_logprob,
-                logprob_start_len=obj.logprob_start_len,
-                top_logprobs_num=obj.top_logprobs_num,
-                stream=obj.stream,
-            )
-            self.send_to_router.send_pyobj(tokenized_obj)
-
-            event = asyncio.Event()
-            state = ReqState([], False, event)
-            self.rid_to_state[rid] = state
-
-            while True:
-                try:
-                    await asyncio.wait_for(event.wait(), timeout=4)
-                except asyncio.TimeoutError:
-                    if request is not None and await request.is_disconnected():
-                        self.abort_request(rid)
-                        raise ValueError(f"Abort request {rid}")
-                    continue
-
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
-                )
-
-                if self.server_args.log_requests and state.finished:
-                    logger.info(f"in={obj.text}, out={out}")
-
-                state.out_list = []
-                if state.finished:
-                    del self.rid_to_state[rid]
-
-                    yield out
-
-                    break
-
-                event.clear()
-
-                yield out
+            async for response in self._handle_single_request(obj, request):
+                yield response
         else:
             if obj.stream:
                 raise ValueError("Do not support stream for batch mode.")

-            if obj.input_ids is None:
-                bs = len(obj.text)
+            async for response in self._handle_batch_request(obj, request):
+                yield response
+
+    async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
+        if is_prefill:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
             else:
-                bs = len(obj.input_ids)
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            # Prefill-only request: no new tokens are generated.
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]
+        else:
+            rid = obj.rid if index is None else obj.rid[index]
+            input_text = obj.text if index is None else obj.text[index]
+            input_ids = (
+                self.tokenizer.encode(input_text)
+                if obj.input_ids is None
+                else obj.input_ids
+            )
+            if index is not None and obj.input_ids:
+                input_ids = obj.input_ids[index]

-        for i in range(bs):
-            rid = obj.rid[i]
+            self._validate_input_length(input_ids)
+            sampling_params = self._get_sampling_params(
+                obj.sampling_params if index is None else obj.sampling_params[index]
+            )
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data if index is None else obj.image_data[index]
+            )
+            return_logprob = (
+                obj.return_logprob if index is None else obj.return_logprob[index]
+            )
+            logprob_start_len = (
+                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+            )
+            top_logprobs_num = (
+                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+            )

-            if obj.input_ids is None:
-                input_text = obj.text[i]
-                input_ids = self.tokenizer.encode(obj.text[i])
+        tokenized_obj = TokenizedGenerateReqInput(
+            rid,
+            input_text,
+            input_ids,
+            pixel_values,
+            image_hash,
+            image_size,
+            sampling_params,
+            return_logprob,
+            logprob_start_len,
+            top_logprobs_num,
+            obj.stream,
+        )
+        self.send_to_router.send_pyobj(tokenized_obj)
+
+        event = asyncio.Event()
+        state = ReqState([], False, event)
+        self.rid_to_state[rid] = state
+        if not is_prefill:
+            async for response in self._wait_for_response(
+                event, state, obj, rid, request
+            ):
+                yield response
+        else:
+            await self._wait_for_prefill_response(event, state, obj, request, rid)
+            yield input_ids
+
+    async def _handle_batch_request(self, obj, request):
+        batch_size = obj.batch_size
+        parallel_sample_num = obj.sampling_params[0].get("n", 1)
+
+        if parallel_sample_num != 1:
+            ## send prefill requests
+            parallel_sample_num += 1
+            input_id_result = [] if obj.input_ids is None else None
+            for i in range(batch_size):
+                async for input_id in self._handle_single_request(
+                    obj, request, index=i, is_prefill=True
+                ):
+                    if input_id_result is not None:
+                        input_id_result.append(input_id)
+            if input_id_result is not None and len(input_id_result) > 1:
+                obj.input_ids = input_id_result
+            elif input_id_result is not None:
+                obj.input_ids = input_id_result[0]
+
+        # First send out all requests
+        for i in range(batch_size):
+            for j in range(parallel_sample_num):
+                if j == 0 and parallel_sample_num != 1:
+                    continue
+                index = i * parallel_sample_num + j
+                if parallel_sample_num != 1:
+                    # With parallel sampling the first batch_size rids belong to the
+                    # prefill requests, so the sampling index is
+                    # j + i * (parallel_sample_num - 1) + batch_size - 1.
+                    index += batch_size - 1 - i
+                rid = obj.rid[index]
+                if parallel_sample_num == 1:
+                    ## select operation
+                    if obj.input_ids is None:
+                        input_text = obj.text[i]
+                        input_ids = self.tokenizer.encode(obj.text[i])
+                    else:
+                        input_text = None
+                        input_ids = obj.input_ids[i]
                 else:
-                input_text = None
-                input_ids = obj.input_ids[i]
+                    if batch_size == 1:
+                        input_text = obj.text
+                        input_ids = obj.input_ids
+                    else:
+                        input_text = obj.text[i]
+                        input_ids = obj.input_ids[i]
+                sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
+                )

-            sampling_params = SamplingParams(**obj.sampling_params[i])
-            if sampling_params.max_new_tokens != 0:
-                sampling_params.normalize(self.tokenizer)
-                sampling_params.verify()
-            if obj.image_data[i] is None:
-                pixel_values, image_hash, image_size = None, None, None
-            else:
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
-                    obj.image_data[i]
-                )
                 tokenized_obj = TokenizedGenerateReqInput(
-                    rid=rid,
-                    input_text=input_text,
-                    input_ids=input_ids,
-                    pixel_values=pixel_values,
-                    image_hash=image_hash,
-                    image_size=image_size,
-                    sampling_params=sampling_params,
-                    return_logprob=obj.return_logprob[i],
-                    logprob_start_len=obj.logprob_start_len[i],
-                    top_logprobs_num=obj.top_logprobs_num[i],
-                    stream=obj.stream,
+                    rid,
+                    input_text,
+                    input_ids,
+                    pixel_values,
+                    image_hash,
+                    image_size,
+                    sampling_params,
+                    obj.return_logprob[index],
+                    obj.logprob_start_len[index],
+                    obj.top_logprobs_num[index],
+                    obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)
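The index arithmetic in the send loop is easier to see with concrete numbers. A sketch (illustrative names, not part of the patch) that enumerates the rid layout for batch_size = 2 and n = 3, after parallel_sample_num has been bumped to 4 by the prefill slot:

# Illustrative sketch: rid slot layout produced by the loops above
# for batch_size = 2 prompts and n = 3 samples each.
batch_size, parallel_sample_num = 2, 3 + 1  # +1 slot for the prefill request
layout = {}
for i in range(batch_size):
    layout[f"prefill_{i}"] = i  # prefill requests take rid[0..batch_size-1]
    for j in range(1, parallel_sample_num):
        index = i * parallel_sample_num + j + (batch_size - 1 - i)
        layout[f"prompt{i}_sample{j}"] = index
print(layout)
# {'prefill_0': 0, 'prompt0_sample1': 2, 'prompt0_sample2': 3, 'prompt0_sample3': 4,
#  'prefill_1': 1, 'prompt1_sample1': 5, 'prompt1_sample2': 6, 'prompt1_sample3': 7}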
@@ -248,9 +273,16 @@ class TokenizerManager:
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state

-        output_list = []
-        for i in range(bs):
-            rid = obj.rid[i]
+        # Then wait for all responses
+        output_list = []
+        for i in range(batch_size):
+            for j in range(parallel_sample_num):
+                if j == 0 and parallel_sample_num != 1:
+                    continue
+                index = i * parallel_sample_num + j
+                if parallel_sample_num != 1:
+                    index += batch_size - 1 - i
+                rid = obj.rid[index]
             state = self.rid_to_state[rid]

             while True:
@@ -263,19 +295,86 @@ class TokenizerManager:
                         self.abort_request(rid)
                         raise ValueError(f"Abort request {rid}")
                     continue
-
                 output_list.append(
                     self.convert_logprob_style(
                         state.out_list[-1],
-                        obj.return_logprob[i],
-                        obj.top_logprobs_num[i],
+                        obj.return_logprob[index],
+                        obj.top_logprobs_num[index],
                         obj.return_text_in_logprobs,
                     )
                 )
                 assert state.finished
                 del self.rid_to_state[rid]
-        yield output_list
+        yield output_list
+
+    def _validate_input_length(self, input_ids):
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+    def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+        sampling_params = SamplingParams(**sampling_params_data)
+        if max_new_tokens is not None:
+            sampling_params.max_new_tokens = max_new_tokens
+        if sampling_params.max_new_tokens != 0:
+            sampling_params.normalize(self.tokenizer)
+            sampling_params.verify()
+        return sampling_params
+
+    async def _get_pixel_values(self, image_data):
+        if isinstance(image_data, list) and len(image_data) > 0:
+            return await self.get_pixel_values(image_data[0])
+        elif isinstance(image_data, str):
+            return await self.get_pixel_values(image_data)
+        else:
+            return None, None, None
+
+    async def _wait_for_response(self, event, state, obj, rid, request):
+        while True:
+            try:
+                await asyncio.wait_for(event.wait(), timeout=4)
+            except asyncio.TimeoutError:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(rid)
+                    raise ValueError(f"Abort request {rid}")
+                continue
+
+            out = self.convert_logprob_style(
+                state.out_list[-1],
+                obj.return_logprob,
+                obj.top_logprobs_num,
+                obj.return_text_in_logprobs,
+            )
+
+            if self.server_args.log_requests and state.finished:
+                logger.info(f"in={obj.text}, out={out}")
+
+            state.out_list = []
+            if state.finished:
+                del self.rid_to_state[rid]
+                yield out
+                break
+
+            event.clear()
+            yield out
+
+    async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+        while True:
+            try:
+                await asyncio.wait_for(state.event.wait(), timeout=4)
+                break
+            except asyncio.TimeoutError:
+                if request is not None and await request.is_disconnected():
+                    for rid in obj.rid:
+                        self.abort_request(rid)
+                    raise ValueError(f"Abort request {rid}")
+                continue
+
+        assert state.finished
+        del self.rid_to_state[rid]

     def flush_cache(self):
         req = FlushCacheReq()
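Both _wait_for_* helpers use the same polling idiom: block on the request's event with a 4-second timeout and treat each timeout as a chance to notice a disconnected client and abort the request server-side. A stripped-down sketch of the pattern (function and parameter names are illustrative, not part of the patch):

import asyncio

async def wait_with_disconnect_check(event, request, abort, rid, timeout=4):
    # Poll the completion event; a timeout is an opportunity to detect
    # that the HTTP client went away and to cancel work server-side.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
            return
        except asyncio.TimeoutError:
            if request is not None and await request.is_disconnected():
                abort(rid)
                raise ValueError(f"Abort request {rid}")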
diff --git a/python/sglang/srt/openai_api_adapter.py b/python/sglang/srt/openai_api_adapter.py
index 4306950f0..f1f09c919 100644
--- a/python/sglang/srt/openai_api_adapter.py
+++ b/python/sglang/srt/openai_api_adapter.py
@@ -95,9 +95,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)

-    if request.n != 1:
-        return create_error_response("n != 1 is not supported")
-
     adapted_request = GenerateReqInput(
         text=request.prompt,
         sampling_params={
@@ -108,6 +105,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
             "regex": request.regex,
+            "n": request.n,
         },
         return_logprob=request.logprobs is not None and request.logprobs > 0,
         top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
@@ -202,46 +200,56 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     except ValueError as e:
         return create_error_response(str(e))

-    ret = ret[0] if isinstance(ret, list) else ret
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
-    text = ret["text"]
-    if request.echo:
-        text = request.prompt + text
+    if not isinstance(ret, list):
+        ret = [ret]
+    choices = []
+
+    for idx, ret_item in enumerate(ret):
+        text = ret_item["text"]

-    if request.logprobs:
         if request.echo:
-            prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
-            prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
+            text = request.prompt + text
+
+        if request.logprobs:
+            if request.echo:
+                prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
+                prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
+            else:
+                prefill_token_logprobs = None
+                prefill_top_logprobs = None
+
+            logprobs = to_openai_style_logprobs(
+                prefill_token_logprobs=prefill_token_logprobs,
+                prefill_top_logprobs=prefill_top_logprobs,
+                decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
+                decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
+            )
         else:
-            prefill_token_logprobs = None
-            prefill_top_logprobs = None
+            logprobs = None

-        logprobs = to_openai_style_logprobs(
-            prefill_token_logprobs=prefill_token_logprobs,
-            prefill_top_logprobs=prefill_top_logprobs,
-            decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
-            decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
+        choice_data = CompletionResponseChoice(
+            index=idx,
+            text=text,
+            logprobs=logprobs,
+            finish_reason=ret_item["meta_info"]["finish_reason"],
         )
-    else:
-        logprobs = None

-    choice_data = CompletionResponseChoice(
-        index=0,
-        text=text,
-        logprobs=logprobs,
-        finish_reason=ret["meta_info"]["finish_reason"],
-    )
+        choices.append(choice_data)

     response = CompletionResponse(
-        id=ret["meta_info"]["id"],
+        id=ret[0]["meta_info"]["id"],
         model=request.model,
-        choices=[choice_data],
+        choices=choices,
         usage=UsageInfo(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+            completion_tokens=sum(
+                item["meta_info"]["completion_tokens"] for item in ret
+            ),
+            total_tokens=ret[0]["meta_info"]["prompt_tokens"]
+            + sum(item["meta_info"]["completion_tokens"] for item in ret),
         ),
     )
+
     return response
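Usage accounting with n > 1: the prompt is prefilled once, so prompt_tokens is read from the first result only, while completion_tokens is summed over all samples. A worked sketch for a 12-token prompt with n = 3 and 32-token completions (numbers are illustrative):

# Sketch: usage math for one 12-token prompt with n=3 samples of 32 tokens each.
prompt_tokens = 12                 # counted once, from ret[0]
completion_tokens = 32 * 3         # summed over all parallel samples -> 96
total_tokens = prompt_tokens + completion_tokens  # 12 + 96 = 108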
@@ -249,9 +257,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)

-    if request.n != 1:
-        return create_error_response("n != 1 is not supported")
-
     # Prep the data needed for the underlying GenerateReqInput:
     #  - prompt: The full prompt string.
     #  - stop: Custom stop tokens.
@@ -292,6 +297,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
             "regex": request.regex,
+            "n": request.n,
         },
         stream=request.stream,
     )
@@ -354,23 +360,37 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     except ValueError as e:
         return create_error_response(str(e))

-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
-    choice_data = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatMessage(role="assistant", content=ret["text"]),
-        finish_reason=ret["meta_info"]["finish_reason"],
-    )
+    if not isinstance(ret, list):
+        ret = [ret]
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
+
+    for idx, ret_item in enumerate(ret):
+        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+        completion_tokens = ret_item["meta_info"]["completion_tokens"]
+
+        choice_data = ChatCompletionResponseChoice(
+            index=idx,
+            message=ChatMessage(role="assistant", content=ret_item["text"]),
+            finish_reason=ret_item["meta_info"]["finish_reason"],
+        )
+
+        choices.append(choice_data)
+        # All choices share the same prompt, so the prompt is counted once.
+        total_prompt_tokens = prompt_tokens
+        total_completion_tokens += completion_tokens

     response = ChatCompletionResponse(
-        id=ret["meta_info"]["id"],
+        id=ret[0]["meta_info"]["id"],
         model=request.model,
-        choices=[choice_data],
+        choices=choices,
         usage=UsageInfo(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            total_tokens=total_prompt_tokens + total_completion_tokens,
         ),
     )
+
     return response
diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py
index f6b4f5706..76b802886 100644
--- a/python/sglang/srt/sampling_params.py
+++ b/python/sglang/srt/sampling_params.py
@@ -20,6 +20,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         dtype: Optional[str] = None,
         regex: Optional[str] = None,
+        n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.dtype = dtype
         self.regex = regex
+        self.n = n

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
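Because n defaults to 1, existing callers are untouched; the parallel-sampling path is only taken when a request passes n explicitly. A sketch of exercising it through the runtime's native /generate endpoint (the endpoint path and payload field names follow SGLang's generate API and GenerateReqInput above; treat the exact shape as an assumption):

import requests

# Sketch: hit the runtime's /generate endpoint directly with n=3.
# GenerateReqInput.post_init sees "n" != 1 and fans the request out into
# one prefill request plus three sampling requests.
payload = {
    "text": "Once upon a time, there was a little",
    "sampling_params": {"temperature": 0.8, "max_new_tokens": 32, "n": 3},
}
print(requests.post("http://127.0.0.1:30000/generate", json=payload).json())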