From c17c57810891591b3f7d5151d65b1e8d13af50f9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 3 Nov 2024 08:38:26 -0800 Subject: [PATCH] Simplify tokenizer manager (#1904) --- .../srt/managers/data_parallel_controller.py | 2 - python/sglang/srt/managers/io_struct.py | 150 +++---- python/sglang/srt/managers/scheduler.py | 7 +- .../sglang/srt/managers/tokenizer_manager.py | 395 +++++------------- python/sglang/srt/openai_api/adapter.py | 122 +++--- python/sglang/srt/server.py | 5 +- test/srt/run_suite.py | 2 +- test/srt/test_openai_server.py | 5 + test/srt/test_skip_tokenizer_init.py | 3 + test/srt/test_srt_endpoint.py | 12 +- test/srt/test_vision_openai_server.py | 1 + 11 files changed, 261 insertions(+), 443 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index dca3d4f01..132827966 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -24,7 +24,6 @@ import zmq from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - TokenizedRewardReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.server_args import PortArgs, ServerArgs @@ -152,7 +151,6 @@ class DataParallelController: ( TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, - TokenizedRewardReqInput, ), ): self.dispatching(recv_req) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index df873035e..339638c0a 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -56,49 +56,47 @@ class GenerateReqInput: # LoRA related lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - # Whether it is a single request or a batch request - is_single: bool = True - - def post_init(self): + def normalize_batch_and_arguments(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") - self.is_single = False + # Derive the batch size if self.text is not None: if isinstance(self.text, str): self.is_single = True self.batch_size = 1 else: + self.is_single = False self.batch_size = len(self.text) else: if isinstance(self.input_ids[0], int): self.is_single = True self.batch_size = 1 else: + self.is_single = False self.batch_size = len(self.input_ids) + # Handle parallel sampling + # When parallel sampling is used, we always treat the input as a batch. if self.sampling_params is None: self.parallel_sample_num = 1 elif isinstance(self.sampling_params, dict): self.parallel_sample_num = self.sampling_params.get("n", 1) else: # isinstance(self.sampling_params, list): self.parallel_sample_num = self.sampling_params[0].get("n", 1) - for sp in self.sampling_params: - # TODO cope with the case that the parallel_sample_num is different for different samples - assert self.parallel_sample_num == sp.get( - "n", 1 - ), "The parallel_sample_num should be the same for all samples in sample params." + assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), ( + "The parallel_sample_num should be the same for all samples in sample params.") - if self.parallel_sample_num > 1: - if self.is_single: - self.is_single = False - if self.text is not None: - self.text = [self.text] - if self.input_ids is not None: - self.input_ids = [self.input_ids] + if self.parallel_sample_num > 1 and self.is_single: + self.is_single = False + if self.text is not None: + self.text = [self.text] + if self.input_ids is not None: + self.input_ids = [self.input_ids] + # Fill in default arguments if self.is_single: if self.sampling_params is None: self.sampling_params = {} @@ -114,8 +112,8 @@ class GenerateReqInput: if self.parallel_sample_num == 1: num = self.batch_size else: - # The first bs samples are used for caching the prefix for parallel sampling - num = self.batch_size + self.parallel_sample_num * self.batch_size + # Expand parallel_sample_num + num = self.batch_size * self.parallel_sample_num if self.image_data is None: self.image_data = [None] * num @@ -128,14 +126,11 @@ class GenerateReqInput: self.sampling_params = [{}] * num elif not isinstance(self.sampling_params, list): self.sampling_params = [self.sampling_params] * num - else: - assert self.parallel_sample_num == 1 if self.rid is None: self.rid = [uuid.uuid4().hex for _ in range(num)] else: assert isinstance(self.rid, list), "The rid should be a list." - assert self.parallel_sample_num == 1 if self.return_logprob is None: self.return_logprob = [False] * num @@ -158,6 +153,26 @@ class GenerateReqInput: else: assert self.parallel_sample_num == 1 + def regenerate_rid(self): + self.rid = uuid.uuid4().hex + return self.rid + + def __getitem__(self, i): + return GenerateReqInput( + text=self.text[i] if self.text is not None else None, + input_ids=self.input_ids[i] if self.input_ids is not None else None, + image_data=self.image_data[i], + sampling_params=self.sampling_params[i], + rid=self.rid[i], + return_logprob=self.return_logprob[i], + logprob_start_len=self.logprob_start_len[i], + top_logprobs_num=self.top_logprobs_num[i], + return_text_in_logprobs=self.return_text_in_logprobs, + stream=self.stream, + modalities=self.modalities[i] if self.modalities else None, + lora_path=self.lora_path[i] if self.lora_path is not None else None, + ) + @dataclass class TokenizedGenerateReqInput: @@ -195,20 +210,29 @@ class EmbeddingReqInput: # Dummy sampling params for compatibility sampling_params: Union[List[Dict], Dict] = None - # Whether it is a single request or a batch request - is_single: bool = True - - def post_init(self): + def normalize_batch_and_arguments(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") + # Derive the batch size if self.text is not None: - self.is_single = isinstance(self.text, str) + if isinstance(self.text, str): + self.is_single = True + self.batch_size = 1 + else: + self.is_single = False + self.batch_size = len(self.text) else: - self.is_single = isinstance(self.input_ids[0], int) + if isinstance(self.input_ids[0], int): + self.is_single = True + self.batch_size = 1 + else: + self.is_single = False + self.batch_size = len(self.input_ids) + # Fill in default arguments if self.is_single: if self.rid is None: self.rid = uuid.uuid4().hex @@ -216,20 +240,28 @@ class EmbeddingReqInput: self.sampling_params = {} self.sampling_params["max_new_tokens"] = 1 else: - # support select operation - self.batch_size = ( - len(self.text) if self.text is not None else len(self.input_ids) - ) if self.rid is None: self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)] else: - if not isinstance(self.rid, list): - raise ValueError("The rid should be a list.") + assert isinstance(self.rid, list), "The rid should be a list." + if self.sampling_params is None: self.sampling_params = [{}] * self.batch_size for i in range(self.batch_size): self.sampling_params[i]["max_new_tokens"] = 1 + def regenerate_rid(self): + self.rid = uuid.uuid4().hex + return self.rid + + def __getitem__(self, i): + return EmbeddingReqInput( + text=self.text[i] if self.text is not None else None, + input_ids=self.input_ids[i] if self.input_ids is not None else None, + sampling_params=self.sampling_params[i], + rid=self.rid[i], + ) + @dataclass class TokenizedEmbeddingReqInput: @@ -243,56 +275,6 @@ class TokenizedEmbeddingReqInput: sampling_params: SamplingParams -RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]] - - -@dataclass -class RewardReqInput: - # The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string. - conv: RewardReqConv - # The request id. - rid: Optional[Union[List[str], str]] = None - # Dummy sampling params for compatibility - sampling_params: Union[List[Dict], Dict] = None - - # Whether it is a single request or a batch request - is_single: bool = True - - def post_init(self): - self.is_single = isinstance(self.conv[0], dict) - - if self.is_single: - if self.rid is None: - self.rid = uuid.uuid4().hex - if self.sampling_params is None: - self.sampling_params = {} - self.sampling_params["max_new_tokens"] = 1 - else: - # support select operation - self.batch_size = len(self.conv) - if self.rid is None: - self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)] - else: - if not isinstance(self.rid, list): - raise ValueError("The rid should be a list.") - if self.sampling_params is None: - self.sampling_params = [{}] * self.batch_size - for i in range(self.batch_size): - self.sampling_params[i]["max_new_tokens"] = 1 - - -@dataclass -class TokenizedRewardReqInput: - # The request id - rid: str - # The input text - input_text: str - # The input token ids - input_ids: List[int] - # Dummy sampling params for compatibility - sampling_params: SamplingParams - - @dataclass class BatchTokenIDOut: # The request id diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 74cd1f3ea..ea98aa696 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -43,7 +43,6 @@ from sglang.srt.managers.io_struct import ( ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - TokenizedRewardReqInput, UpdateWeightReqInput, UpdateWeightReqOutput, ) @@ -394,9 +393,7 @@ class Scheduler: for recv_req in recv_reqs: if isinstance(recv_req, TokenizedGenerateReqInput): self.handle_generate_request(recv_req) - elif isinstance( - recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput) - ): + elif isinstance(recv_req, TokenizedEmbeddingReqInput): self.handle_embedding_request(recv_req) elif isinstance(recv_req, FlushCacheReq): self.flush_cache() @@ -487,7 +484,7 @@ class Scheduler: def handle_embedding_request( self, - recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput], + recv_req: TokenizedEmbeddingReqInput, ): req = Req( recv_req.rid, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c7d0bc783..1a9dc5e2b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -16,6 +16,7 @@ limitations under the License. """TokenizerManager is a process that tokenizes the text.""" import asyncio +import copy import dataclasses import json import logging @@ -51,11 +52,8 @@ from sglang.srt.managers.io_struct import ( GetMemPoolSizeReq, GetMemPoolSizeReqOutput, ProfileReq, - RewardReqConv, - RewardReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - TokenizedRewardReqInput, UpdateWeightReqInput, UpdateWeightReqOutput, ) @@ -157,7 +155,7 @@ class TokenizerManager: async def generate_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, ): if self.to_create_loop: @@ -172,122 +170,54 @@ class TokenizerManager: "Please add `--is-embedding` when launching the server or try another model." ) - obj.post_init() + obj.normalize_batch_and_arguments() is_single = obj.is_single if is_single: - async for response in self._handle_single_request(obj, request): + tokenized_obj = await self._tokenize_one_request(obj) + self.send_to_scheduler.send_pyobj(tokenized_obj) + async for response in self._wait_one_response(obj, request): yield response else: async for response in self._handle_batch_request(obj, request): yield response - async def _send_single_request( + async def _tokenize_one_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], - index: Optional[int] = None, - input_id_index: Optional[int] = None, - is_cache_for_prefill: Optional[bool] = False, + obj: Union[GenerateReqInput, EmbeddingReqInput], ): - if not is_cache_for_prefill: # The normal case with a single prompt - if index is None: - rid = obj.rid - if isinstance(obj, RewardReqInput): - input_text = self._apply_chat_template(obj.conv) - input_ids = self.tokenizer.encode(input_text) - elif obj.input_ids is None: - input_text = obj.text - input_ids = self.tokenizer.encode(input_text) - else: - input_text = obj.text if obj.text is not None else None - input_ids = obj.input_ids + """Tokenize one request.""" + # Tokenize + input_text = obj.text + if obj.input_ids is None: + input_ids = self.tokenizer.encode(input_text) + else: + input_ids = obj.input_ids - sampling_params = self._get_sampling_params(obj.sampling_params) - if self.is_generation: - image_inputs = await self.image_processor.process_images_async( - obj.image_data, input_text or input_ids, obj - ) - if image_inputs and "input_ids" in image_inputs: - input_ids = image_inputs["input_ids"] - return_logprob = obj.return_logprob - logprob_start_len = obj.logprob_start_len - top_logprobs_num = obj.top_logprobs_num - else: - rid = obj.rid[index] - if isinstance(obj, RewardReqInput): - input_text = self._apply_chat_template(obj.conv[input_id_index]) - input_ids = self.tokenizer.encode(input_text) - elif obj.input_ids is None: - input_text = obj.text[input_id_index] - input_ids = self.tokenizer.encode(input_text) - else: - input_text = ( - obj.text[input_id_index] if obj.text is not None else None - ) - input_ids = obj.input_ids[input_id_index] - - sampling_params = self._get_sampling_params(obj.sampling_params[index]) - if self.is_generation: - image_inputs = await self.image_processor.process_images_async( - obj.image_data[index], input_text or input_ids, obj - ) - if image_inputs and "input_ids" in image_inputs: - input_ids = image_inputs["input_ids"] - return_logprob = obj.return_logprob[index] - logprob_start_len = obj.logprob_start_len[index] - top_logprobs_num = obj.top_logprobs_num[index] - - self._validate_input_length(input_ids) - - else: # A prefill request to cache the common prompt for parallel sampling - assert self.is_generation - if obj.text is not None: - if isinstance(obj.text, list): - input_text = obj.text[input_id_index] - rid = obj.rid[index] - else: - input_text = obj.text - rid = obj.rid[0] - if self.tokenizer is not None: - input_ids = self.tokenizer.encode(input_text) - else: - assert obj.input_ids is not None - input_ids = obj.input_ids - if isinstance(obj.input_ids, list) and isinstance( - obj.input_ids[0], list - ): - # when obj["input_ids"] is List[List[int]] - input_ids = obj.input_ids[input_id_index] - rid = obj.rid[index] - else: - input_ids = obj.input_ids - rid = obj.rid[0] - else: - input_text = None - if isinstance(obj.input_ids, list) and isinstance( - obj.input_ids[0], list - ): - # when obj["input_ids"] is List[List[int]] - input_ids = obj.input_ids[input_id_index] - rid = obj.rid[index] - else: - input_ids = obj.input_ids - rid = obj.rid[0] - - sampling_params = SamplingParams(**obj.sampling_params[0]) - sampling_params.max_new_tokens = 0 + if self.is_generation: image_inputs = await self.image_processor.process_images_async( - obj.image_data[0], input_text or input_ids, obj + obj.image_data, input_text or input_ids, obj ) if image_inputs and "input_ids" in image_inputs: input_ids = image_inputs["input_ids"] - return_logprob = obj.return_logprob[0] - logprob_start_len = obj.logprob_start_len[0] - top_logprobs_num = obj.top_logprobs_num[0] + return_logprob = obj.return_logprob + logprob_start_len = obj.logprob_start_len + top_logprobs_num = obj.top_logprobs_num - # Send to the controller - if self.is_generation: + if len(input_ids) >= self.context_len: + raise ValueError( + f"The input ({len(input_ids)} tokens) is longer than the " + f"model's context length ({self.context_len} tokens)." + ) + + # Parse sampling parameters + sampling_params = SamplingParams(**obj.sampling_params) + sampling_params.normalize(self.tokenizer) + sampling_params.verify() + + # Build return object + if isinstance(obj, GenerateReqInput): tokenized_obj = TokenizedGenerateReqInput( - rid, + obj.rid, input_text, input_ids, image_inputs, @@ -296,219 +226,126 @@ class TokenizerManager: logprob_start_len, top_logprobs_num, obj.stream, - ( - obj.lora_path[input_id_index] - if isinstance(obj.lora_path, list) - else obj.lora_path - ), + obj.lora_path ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( - rid, - input_text, - input_ids, - sampling_params, - ) - else: - assert isinstance(obj, RewardReqInput) - tokenized_obj = TokenizedRewardReqInput( - rid, + obj.rid, input_text, input_ids, sampling_params, ) - self.send_to_scheduler.send_pyobj(tokenized_obj) - return rid, input_ids + return tokenized_obj - async def _handle_single_request( + async def _wait_one_response( self, - obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], + obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, - index: Optional[int] = None, - input_id_index: Optional[int] = None, - is_cache_for_prefill: Optional[bool] = False, ): - rid, input_ids = await self._send_single_request( - obj, - index, - input_id_index=input_id_index, - is_cache_for_prefill=is_cache_for_prefill, - ) - - # Recv results + """Wait for the response of one request.""" event = asyncio.Event() state = ReqState([], False, event) - self.rid_to_state[rid] = state + self.rid_to_state[obj.rid] = state - if not is_cache_for_prefill: - async for response in self._wait_for_response(state, obj, rid, request): - yield response - else: - await state.event.wait() - assert state.finished - del self.rid_to_state[rid] - yield input_ids - - async def _handle_batch_request( - self, - obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], - request: Optional[fastapi.Request] = None, - ): - batch_size = obj.batch_size - if self.is_generation: - parallel_sample_num = obj.parallel_sample_num - - if parallel_sample_num != 1: - # Send prefill requests to cache the common prefix - parallel_sample_num += 1 - input_id_result = [] if obj.input_ids is None else None - for i in range(batch_size): - async for input_id in self._handle_single_request( - obj, - request, - index=i, - input_id_index=i, - is_cache_for_prefill=True, - ): - if input_id_result is not None: - input_id_result.append(input_id) - if input_id_result is not None: - obj.input_ids = input_id_result - else: - parallel_sample_num = 1 - - # First send out all requests - generators = [] - for i in range(batch_size): - for j in range(parallel_sample_num): - if j == 0 and parallel_sample_num != 1: - continue - index = i * parallel_sample_num + j - if parallel_sample_num != 1: - # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1 - index += batch_size - 1 - i - - rid, _ = await self._send_single_request( - obj, index, input_id_index=i, is_cache_for_prefill=False - ) - - event = asyncio.Event() - state = ReqState([], False, event) - self.rid_to_state[rid] = state - - generators.append( - self._wait_for_response( - state, - obj, - rid, - request, - index=index, - response_index=len(generators), - ) - ) - - # Then process the responses based on streaming option - is_stream = hasattr(obj, "stream") and obj.stream - - tasks = [asyncio.create_task(gen.__anext__()) for gen in generators] - output_list = [None] * len(tasks) - - # Fetch results - while tasks: - done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - - for task in done: - cur_index = tasks.index(task) - - try: - result = task.result() - - if is_stream: - yield result - else: - output_list[result["index"]] = result - - tasks[cur_index] = asyncio.create_task( - generators[cur_index].__anext__() - ) - except StopAsyncIteration: - del generators[cur_index] - del tasks[cur_index] - - if not is_stream: - yield output_list - - def _validate_input_length(self, input_ids: List[int]): - if len(input_ids) >= self.context_len: - raise ValueError( - f"The input ({len(input_ids)} tokens) is longer than the " - f"model's context length ({self.context_len} tokens)." - ) - - def _get_sampling_params(self, sampling_params_data: dict): - sampling_params = SamplingParams(**sampling_params_data) - if sampling_params.max_new_tokens != 0: - sampling_params.normalize(self.tokenizer) - sampling_params.verify() - return sampling_params - - def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]: - if isinstance(conv, str): - return conv - elif isinstance(conv, list): - if isinstance(conv[0], str): - return conv - else: - return self.tokenizer.apply_chat_template(conv, tokenize=False) - - async def _wait_for_response( - self, - state: ReqState, - obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], - rid: str, - request: Optional[fastapi.Request] = None, - index: Optional[int] = None, - response_index: int = 0, - ): while True: try: await asyncio.wait_for(state.event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): - for rid in [obj.rid] if obj.is_single else obj.rid: - self.abort_request(rid) - raise ValueError(f"Abort request {rid}") + self.abort_request(obj.rid) + raise ValueError(f"Abort request {obj.rid}") continue - if self.is_generation: + if isinstance(obj, GenerateReqInput): out = self.convert_logprob_style( state.out_list[-1], - obj.return_logprob if index is None else obj.return_logprob[index], - ( - obj.top_logprobs_num - if index is None - else obj.top_logprobs_num[index] - ), + obj.return_logprob, + obj.top_logprobs_num, obj.return_text_in_logprobs, ) - else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput)) + else: # isinstance(obj, (EmbeddingReqInput,)) out = state.out_list[-1] - out["index"] = response_index - state.out_list = [] if state.finished: - # Log requests if self.server_args.log_requests: + # Log requests logger.info(f"in={obj}, out={out}") - del self.rid_to_state[rid] + del self.rid_to_state[obj.rid] yield out break state.event.clear() yield out + async def _handle_batch_request( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, + ): + batch_size = obj.batch_size + + generators = [] + rids = [] + if getattr(obj, "parallel_sample_num", 1) == 1: + # Send all requests + for i in range(batch_size): + tmp_obj = obj[i] + tokenized_obj = await self._tokenize_one_request(tmp_obj) + self.send_to_scheduler.send_pyobj(tokenized_obj) + generators.append(self._wait_one_response(tmp_obj, request)) + rids.append(tmp_obj.rid) + else: + # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. + + # Tokenize all requests + objs = [obj[i] for i in range(batch_size)] + tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs)) + + # Cache the common prefix for parallel sampling + for i in range(batch_size): + tmp_obj = copy.copy(objs[i]) + tokenized_obj = copy.copy(tokenized_objs[i]) + tokenized_obj.rid = tmp_obj.regenerate_rid() + tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params) + tokenized_obj.sampling_params.max_new_tokens = 0 + tokenized_obj.stream = False + self.send_to_scheduler.send_pyobj(tokenized_obj) + await self._wait_one_response(tmp_obj, request).__anext__() + + # Expand requests, assign new rids for them, and send them + for i in range(batch_size): + for _ in range(obj.parallel_sample_num): + tmp_obj = copy.copy(objs[i]) + tokenized_obj = copy.copy(tokenized_objs[i]) + tokenized_obj.rid = tmp_obj.regenerate_rid() + self.send_to_scheduler.send_pyobj(tokenized_obj) + generators.append(self._wait_one_response(tmp_obj, request)) + rids.append(tmp_obj.rid) + + # Wait for all requests + is_stream = hasattr(obj, "stream") and obj.stream + if not is_stream: + outputs = await asyncio.gather(*(gen.__anext__() for gen in generators)) + yield outputs + else: + rid_to_index = {rid: i for i, rid in enumerate(rids)} + task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators} + while task_map: + done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED) + + for task in done: + gen = task_map.pop(task) + try: + result = task.result() + result["index"] = rid_to_index[result["meta_info"]["id"]] + yield result + new_task = asyncio.create_task(gen.__anext__()) + task_map[new_task] = gen + except StopAsyncIteration: + pass + def flush_cache(self): req = FlushCacheReq() self.send_to_scheduler.send_pyobj(req) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 0b820a8b0..c1e399241 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import ( TopLogprob, UsageInfo, ) +from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe ) except Exception as e: + logger.error(f"error: {get_exception_traceback()}") + responses = [] error_json = { "id": f"batch_req_{uuid.uuid4()}", "custom_id": request_data.get("custom_id"), @@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe } except Exception as e: - logger.error("error in SGLang:", e) + logger.error(f"error: {e}") # Update batch status to "failed" retrieve_batch = batch_storage[batch_id] retrieve_batch.status = "failed" @@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str): def v1_generate_request( all_requests: List[CompletionRequest], request_ids: List[str] = None ): + if len(all_requests) > 1: + first_prompt_type = type(all_requests[0].prompt) + for request in all_requests: + assert ( + type(request.prompt) is first_prompt_type + ), "All prompts must be of the same type in file input settings" + if request.n > 1: + raise ValueError( + "Parallel sampling is not supported for completions from files" + ) + prompts = [] sampling_params_list = [] return_logprobs = [] logprob_start_lens = [] top_logprobs_nums = [] - # NOTE: with openai API, the prompt's logprobs are always not computed - first_prompt_type = type(all_requests[0].prompt) for request in all_requests: - assert ( - type(request.prompt) is first_prompt_type - ), "All prompts must be of the same type in file input settings" - if len(all_requests) > 1 and request.n > 1: - raise ValueError( - "Parallel sampling is not supported for completions from files" - ) + # NOTE: with openai API, the prompt's logprobs are always not computed if request.echo and request.logprobs: logger.warning( "Echo is not compatible with logprobs. " - "To compute logprobs of input prompt, please use SGLang /request API." + "To compute logprobs of input prompt, please use the native /generate API." ) - for request in all_requests: prompts.append(request.prompt) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": request.stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "json_schema": request.json_schema, + "n": request.n, + "ignore_eos": request.ignore_eos, + "no_stop_trim": request.no_stop_trim, + } + ) return_logprobs.append(request.logprobs is not None and request.logprobs > 0) logprob_start_lens.append(-1) top_logprobs_nums.append( request.logprobs if request.logprobs is not None else 0 ) - sampling_params = [] - if isinstance(request.no_stop_trim, list): - num_reqs = len(request.prompt) - else: - num_reqs = 1 - for i in range(num_reqs): - sampling_params.append( - { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "min_new_tokens": request.min_tokens, - "stop": request.stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "json_schema": request.json_schema, - "n": request.n, - "ignore_eos": request.ignore_eos, - "no_stop_trim": ( - request.no_stop_trim - if not isinstance(request.no_stop_trim, list) - else request.no_stop_trim[i] - ), - } - ) - if num_reqs == 1: - sampling_params_list.append(sampling_params[0]) - else: - sampling_params_list.append(sampling_params) if len(all_requests) == 1: - prompt = prompts[0] - sampling_params_list = sampling_params_list[0] - logprob_start_lens = logprob_start_lens[0] - return_logprobs = return_logprobs[0] - top_logprobs_nums = top_logprobs_nums[0] - if isinstance(prompt, str) or isinstance(prompt[0], str): - prompt_kwargs = {"text": prompt} + if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): + prompt_kwargs = {"text": prompts[0]} else: - prompt_kwargs = {"input_ids": prompt} + prompt_kwargs = {"input_ids": prompts[0]} + sampling_params_list = sampling_params_list[0] + return_logprobs = return_logprobs[0] + logprob_start_lens = logprob_start_lens[0] + top_logprobs_nums = top_logprobs_nums[0] else: - if isinstance(prompts[0], str): + if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): prompt_kwargs = {"text": prompts} else: prompt_kwargs = {"input_ids": prompts} @@ -558,9 +548,7 @@ def v1_generate_request( rid=request_ids, ) - if len(all_requests) == 1: - return adapted_request, all_requests[0] - return adapted_request, all_requests + return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] def v1_generate_response(request, ret, tokenizer_manager, to_file=False): @@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): if isinstance(request, list) and request[idx].echo: echo = True text = request[idx].prompt + text - if (not isinstance(request, list)) and echo: + if echo and not isinstance(request, list): prompt_index = idx // request.n text = prompts[prompt_index] + text @@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request): async for content in tokenizer_manager.generate_request( adapted_request, raw_request ): - index = content["index"] + index = content.get("index", 0) stream_buffer = stream_buffers.get(index, "") n_prev_token = n_prev_tokens.get(index, 0) @@ -945,19 +933,18 @@ def v1_chat_generate_request( sampling_params_list.append(sampling_params) image_data_list.append(image_data) - modalities_list.extend(modalities) + modalities_list.append(modalities) if len(all_requests) == 1: - input_ids = input_ids[0] - if isinstance(input_ids, str): - prompt_kwargs = {"text": input_ids} + if isinstance(input_ids[0], str): + prompt_kwargs = {"text": input_ids[0]} else: - prompt_kwargs = {"input_ids": input_ids} + prompt_kwargs = {"input_ids": input_ids[0]} sampling_params_list = sampling_params_list[0] image_data_list = image_data_list[0] return_logprobs = return_logprobs[0] logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] - modalities_list = modalities_list[:1] + modalities_list = modalities_list[0] else: if isinstance(input_ids[0], str): prompt_kwargs = {"text": input_ids} @@ -976,9 +963,8 @@ def v1_chat_generate_request( rid=request_ids, modalities=modalities_list, ) - if len(all_requests) == 1: - return adapted_request, all_requests[0] - return adapted_request, all_requests + + return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): @@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): async for content in tokenizer_manager.generate_request( adapted_request, raw_request ): - index = content["index"] + index = content.get("index", 0) is_first = is_firsts.get(index, True) stream_buffer = stream_buffers.get(index, "") diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 6c46551c1..854a0c658 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -53,7 +53,6 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, - RewardReqInput, UpdateWeightReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process @@ -91,7 +90,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) app = FastAPI() -tokenizer_manager = None +tokenizer_manager: TokenizerManager = None app.add_middleware( CORSMiddleware, @@ -254,7 +253,7 @@ app.post("/encode")(encode_request) app.put("/encode")(encode_request) -async def judge_request(obj: RewardReqInput, request: Request): +async def judge_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request.""" try: ret = await tokenizer_manager.generate_request(obj, request).__anext__() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f7277f03d..ffc8e84f2 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -8,7 +8,7 @@ suites = { "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", - "models/test_reward_models.py", + # "models/test_reward_models.py", "sampling/penaltylib", "test_chunked_prefill.py", "test_double_sparsity.py", diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index d3e21d04b..d6ae76b8a 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -1,3 +1,8 @@ +""" +python3 -m unittest test_openai_server.TestOpenAIServer.test_batch +python3 -m unittest test_openai_server.TestOpenAIServer.test_completion + +""" import json import time import unittest diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index a5dcde4a2..3631780da 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -1,3 +1,6 @@ +""" +python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample +""" import json import unittest diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index e1b5318c0..045f3100c 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -1,5 +1,6 @@ """ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode +python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_parallel_sample """ import json @@ -36,11 +37,17 @@ class TestSRTEndpoint(unittest.TestCase): return_text=False, n=1, stream=False, + batch=False, ): + if batch: + text = ["The capital of France is"] + else: + text = "The capital of France is" + response = requests.post( self.base_url + "/generate", json={ - "text": "The capital of France is", + "text": text, "sampling_params": { "temperature": 0 if n == 1 else 0.5, "max_new_tokens": 16, @@ -67,6 +74,9 @@ class TestSRTEndpoint(unittest.TestCase): def test_simple_decode(self): self.run_decode() + def test_simple_decode_batch(self): + self.run_decode(batch=True) + def test_parallel_sample(self): self.run_decode(n=3) diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index a5bf302e2..d70dac66f 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -1,6 +1,7 @@ """ Usage: python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch +python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion """ import base64