diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index edbfff3ec..e157217e3 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -153,9 +153,6 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if hasattr(obj, "stream") and obj.stream:
-                raise ValueError("Do not support stream for batch mode.")
-
             async for response in self._handle_batch_request(obj, request):
                 yield response
@@ -311,6 +308,7 @@ class TokenizerManager:
             parallel_sample_num = 1

         # First send out all requests
+        generators = []
         for i in range(batch_size):
             for j in range(parallel_sample_num):
                 if j == 0 and parallel_sample_num != 1:
@@ -371,42 +369,48 @@ class TokenizerManager:
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

-        # Then wait for all responses
-        output_list = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    index += batch_size - 1 - i
-                rid = obj.rid[index]
-                state = self.rid_to_state[rid]
-
-                while True:
-                    try:
-                        await asyncio.wait_for(state.event.wait(), timeout=4)
-                        break
-                    except asyncio.TimeoutError:
-                        if request is not None and await request.is_disconnected():
-                            for rid in obj.rid:
-                                self.abort_request(rid)
-                            raise ValueError(f"Abort request {rid}")
-                        continue
-                if self.is_generation:
-                    output_list.append(
-                        self.convert_logprob_style(
-                            state.out_list[-1],
-                            obj.return_logprob[index],
-                            obj.top_logprobs_num[index],
-                            obj.return_text_in_logprobs,
-                        )
+                generators.append(
+                    self._wait_for_response(
+                        event,
+                        state,
+                        obj,
+                        rid,
+                        request,
+                        index=index,
+                        response_index=len(generators),
                     )
-                else:
-                    output_list.append(state.out_list[-1])
-                assert state.finished
-                del self.rid_to_state[rid]
-        yield output_list
+                )
+
+        # Then process the responses based on streaming option
+
+        is_stream = hasattr(obj, "stream") and obj.stream
+
+        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
+        output_list = []
+
+        while tasks:
+            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+            for task in done:
+                gen_index = tasks.index(task)
+
+                try:
+                    result = task.result()
+
+                    if is_stream:
+                        yield result
+                    else:
+                        output_list.append(result)
+
+                    tasks[gen_index] = asyncio.create_task(
+                        generators[gen_index].__anext__()
+                    )
+                except StopAsyncIteration:
+                    del generators[gen_index]
+                    del tasks[gen_index]
+
+        if not is_stream:
+            yield output_list

     def _validate_input_length(self, input_ids: List[int]):
         if len(input_ids) >= self.context_len:
@@ -437,26 +441,35 @@ class TokenizerManager:
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
         request,
+        index: int = None,
+        response_index: int = 0,
    ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-                    self.abort_request(rid)
+                    for rid in [obj.rid] if obj.is_single else obj.rid:
+                        self.abort_request(rid)
                     raise ValueError(f"Abort request {rid}")
                 continue

             if self.is_generation:
                 out = self.convert_logprob_style(
                     state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
+                    obj.return_logprob if index is None else obj.return_logprob[index],
+                    (
+                        obj.top_logprobs_num
+                        if index is None
+                        else obj.top_logprobs_num[index]
+                    ),
                     obj.return_text_in_logprobs,
                 )
             else:  # isinstance(obj, EmbeddingReqInput)
                 out = state.out_list[-1]

+            out["index"] = response_index
+
             # Log requests
             if self.server_args.log_requests and state.finished:
                 if obj.text is None:
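The hunks above replace the serial wait loop with a completion-order fan-in: each request gets its own async generator, and `asyncio.wait` multiplexes their `__anext__()` tasks so whichever sample produces a chunk first is yielded first. A minimal standalone sketch of the same pattern (the `make_gen` and `merge` names are illustrative, not from the patch):

```python
import asyncio


async def make_gen(tag: str, delays: list[float]):
    # Stand-in for one per-request generator such as _wait_for_response().
    for i, d in enumerate(delays):
        await asyncio.sleep(d)
        yield f"{tag}-{i}"


async def merge(generators):
    # One pending __anext__() task per generator, re-armed after every item.
    tasks = [asyncio.create_task(g.__anext__()) for g in generators]
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            i = tasks.index(task)  # identity lookup keeps tasks/generators aligned
            try:
                yield task.result()
                tasks[i] = asyncio.create_task(generators[i].__anext__())
            except StopAsyncIteration:
                # Delete both entries at the same index to preserve the pairing.
                del generators[i]
                del tasks[i]


async def main():
    gens = [make_gen("a", [0.2, 0.4]), make_gen("b", [0.1, 0.2])]
    async for item in merge(gens):
        print(item)  # b-0, a-0, b-1, a-1 (order follows completion time)


asyncio.run(main())
```

Re-arming only the generator that produced an item keeps exactly one pending task per stream, which is what lets batch mode interleave parallel samples instead of draining them one by one.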
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 5d7bb7af7..12b40d6c4 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -277,6 +277,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+
+            # Although streaming is supported for standalone completions, it is not supported in
+            # batch mode (multiple completions in single request).
+            if body.get("stream", False):
+                raise ValueError("Streaming requests are not supported in batch mode")
+
             if end_point == "/v1/chat/completions":
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
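The `process_batch` hunk rejects `"stream": true` up front, since the batch path returns one aggregated output per request. A client-side mirror of that check (the helper and file name are hypothetical; only the rejected field comes from the patch):

```python
import json


def validate_batch_file(path: str) -> None:
    # Fail fast before uploading a .jsonl batch whose lines request streaming.
    with open(path) as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            body = json.loads(line)["body"]
            if body.get("stream", False):
                raise ValueError(
                    f"line {lineno}: streaming requests are not supported in batch mode"
                )


validate_batch_file("batch_input.jsonl")
```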
@@ -592,27 +598,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]

                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list) and isinstance(
-                                request.prompt[0], int
-                            ):
-                                prompts = tokenizer_manager.tokenizer.decode(
-                                    request.prompt, skip_special_tokens=True
-                                )
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )

                             # Prepend prompt in response text.
                             text = prompts + text
@@ -649,7 +673,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -662,12 +686,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )

                     final_usage_chunk = CompletionStreamResponse(
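In `v1_completions`, every chunk now carries the `index` assigned by the tokenizer manager, so the handler keeps per-choice state in dicts instead of a single `stream_buffer`, and recovers the owning prompt with `index // request.n`. A self-contained sketch of that bookkeeping (the chunk data is fabricated; the mapping rule is the one used in the patch):

```python
n = 2  # request.n: two samples per prompt
prompts = ["Paris is", "Berlin is"]

# (choice index, text delta) pairs in arbitrary arrival order, as the
# tokenizer manager would interleave them.
chunks = [(0, " the"), (2, " the"), (1, " a"), (0, " capital"), (3, " grey")]

stream_buffers: dict[int, str] = {}
for index, delta in chunks:
    if index not in stream_buffers:  # first chunk for this choice
        # Echo case: choices 0..n-1 belong to prompt 0, n..2n-1 to prompt 1, etc.
        print(f"choice {index} belongs to prompt {prompts[index // n]!r}")
    stream_buffers[index] = stream_buffers.get(index, "") + delta

print(stream_buffers)
# {0: ' the capital', 2: ' the', 1: ' a', 3: ' grey'}
```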
@@ -914,16 +950,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            n_prev_token = 0
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -973,7 +1016,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=0,
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -991,7 +1034,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -1003,12 +1046,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )

                     final_usage_chunk = ChatCompletionStreamResponse(
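Both streaming handlers aggregate usage the same way: every chunk reports the prompt length of its own prompt, so summing `prompt_tokens` only where `i % request.n == 0` (one choice per prompt) counts each prompt once, while completion tokens are summed over all choices. The arithmetic in isolation (token counts are made up):

```python
n = 2
prompt_tokens = {0: 5, 1: 5, 2: 7, 3: 7}       # per choice index
completion_tokens = {0: 11, 1: 9, 2: 12, 3: 8}

total_prompt_tokens = sum(t for i, t in prompt_tokens.items() if i % n == 0)
total_completion_tokens = sum(completion_tokens.values())

assert total_prompt_tokens == 12        # 5 + 7, not (5 + 5 + 7 + 7)
assert total_completion_tokens == 40
print(total_prompt_tokens + total_completion_tokens)  # 52
```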
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index c62fefe9f..828f5ab53 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -85,13 +85,26 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs, token_input):
+    def run_completion_stream(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
         if token_input:
-            prompt_arg = self.tokenizer.encode(prompt)
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
         else:
-            prompt_arg = prompt
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
         generator = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
@@ -101,9 +114,10 @@ class TestOpenAIServer(unittest.TestCase):
             logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )

-        first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -111,10 +125,14 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.completion_tokens > 0
                 assert usage.total_tokens > 0
                 continue

+            index = response.choices[0].index
+            is_first = is_firsts.get(index, True)
+
             if logprobs:
                 assert response.choices[0].logprobs
                 assert isinstance(response.choices[0].logprobs.tokens[0], str)
-                if not (first and echo):
+                if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
                     )
@@ -125,15 +143,20 @@ class TestOpenAIServer(unittest.TestCase):
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                     assert ret_num_top_logprobs > 0

-            if first:
+            if is_first:
                 if echo:
                     assert response.choices[0].text.startswith(
                         prompt
-                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
-                first = False
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
+                is_firsts[index] = False

             assert response.id
             assert response.created

+        for index in [i for i in range(parallel_sample_num * num_choices)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
@@ -172,7 +195,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_chat_completion_stream(self, logprobs):
+    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
@@ -185,9 +208,10 @@ class TestOpenAIServer(unittest.TestCase):
             top_logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )

-        is_first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -196,11 +220,12 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.total_tokens > 0
                 continue

+            index = response.choices[0].index
             data = response.choices[0].delta
-            if is_first:
-                data.role == "assistant"
-                is_first = False
+            if is_firsts.get(index, True):
+                assert data.role == "assistant"
+                is_firsts[index] = False
                 continue

             if logprobs:
@@ -222,6 +247,11 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.id
         assert response.created

+        for index in [i for i in range(parallel_sample_num)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_batch(self, mode):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         if mode == "completion":
@@ -320,7 +350,9 @@ class TestOpenAIServer(unittest.TestCase):
                 f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
             )
             batch_job = client.batches.retrieve(batch_job.id)
-        assert batch_job.status == "completed"
+        assert (
+            batch_job.status == "completed"
+        ), f"Batch job status is not completed: {batch_job.status}"
         assert batch_job.request_counts.completed == len(content)
         assert batch_job.request_counts.failed == 0
         assert batch_job.request_counts.total == len(content)
@@ -353,8 +385,16 @@ class TestOpenAIServer(unittest.TestCase):
         # parallel sampling and list input are not supported in streaming mode
         for echo in [False, True]:
             for logprobs in [None, 5]:
-                for token_input in [False, True]:
-                    self.run_completion_stream(echo, logprobs, token_input)
+                for use_list_input in [True, False]:
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion_stream(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

     def test_chat_completion(self):
         for logprobs in [None, 5]:
@@ -363,7 +403,8 @@ class TestOpenAIServer(unittest.TestCase):

     def test_chat_completion_stream(self):
         for logprobs in [None, 5]:
-            self.run_chat_completion_stream(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion_stream(logprobs, parallel_sample_num)

     def test_batch(self):
         for mode in ["completion", "chat"]:
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 5e6bcbf60..60f4cd58a 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -23,7 +23,12 @@ class TestSRTEndpoint(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_decode(
-        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+        self,
+        return_logprob=False,
+        top_logprobs_num=0,
+        return_text=False,
+        n=1,
+        stream=False,
     ):
         response = requests.post(
             self.base_url + "/generate",
@@ -34,14 +39,21 @@ class TestSRTEndpoint(unittest.TestCase):
                     "max_new_tokens": 32,
                     "n": n,
                 },
-                "stream": False,
+                "stream": stream,
                 "return_logprob": return_logprob,
                 "top_logprobs_num": top_logprobs_num,
                 "return_text_in_logprobs": return_text,
                 "logprob_start_len": 0,
             },
         )
-        print(json.dumps(response.json()))
+        if not stream:
+            response_json = response.json()
+        else:
+            response_json = []
+            for line in response.iter_lines():
+                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
+                    response_json.append(json.loads(line[6:]))
+        print(json.dumps(response_json))
         print("=" * 100)

     def test_simple_decode(self):
@@ -50,6 +62,9 @@ class TestSRTEndpoint(unittest.TestCase):
     def test_parallel_sample(self):
         self.run_decode(n=3)

+    def test_parallel_sample_stream(self):
+        self.run_decode(n=3, stream=True)
+
     def test_logprob(self):
         for top_logprobs_num in [0, 3]:
             for return_text in [True, False]:
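The new `test_parallel_sample_stream` consumes `/generate` as server-sent events: each payload line is prefixed with `data: ` and the stream is terminated by `data: [DONE]`. A sketch of the same client loop outside the test harness (URL and sampling params are placeholders):

```python
import json

import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32, "n": 3},
        "stream": True,
    },
    stream=True,  # let requests hand lines over as they arrive
)

for line in response.iter_lines():
    if line.startswith(b"data: ") and line[6:] != b"[DONE]":
        chunk = json.loads(line[6:])
        # Each chunk carries the "index" set by the tokenizer manager,
        # identifying which of the n parallel samples it belongs to.
        print(chunk.get("index"), chunk.get("text", ""))
```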