From f95e6617576ee30597ecf1d5de7e6feada70394d Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 27 Jul 2024 15:44:27 -0700
Subject: [PATCH] Fix max_tokens for OpenAI chat completion API (#766)

---
 .../srt/managers/controller/tp_worker.py | 34 ++++++++++---------
 python/sglang/srt/openai_api/protocol.py |  2 +-
 python/sglang/srt/sampling_params.py     |  9 ++---
 3 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index b04f0aa2d..8a8cab974 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -98,17 +98,21 @@ class ModelTpServer:
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
-        self.max_running_requests = (
-            self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None
-            else server_args.max_running_requests
-        )
         self.max_running_requests = min(
-            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+            (
+                self.max_total_num_tokens // 2
+                if server_args.max_running_requests is None
+                else server_args.max_running_requests
+            ),
+            self.model_runner.req_to_token_pool.size - 1,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
+        self.max_req_input_len = min(
+            self.model_config.context_len - 1,
+            self.max_total_num_tokens - 1,
+        )
         set_random_seed(server_args.random_seed)
 
         # Print info
@@ -295,18 +299,16 @@ class ModelTpServer:
                 )
 
         # Truncate prompts that are too long
-        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
         req.sampling_params.max_new_tokens = min(
-            req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.origin_input_ids),
-            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+            req.sampling_params.max_new_tokens or 1 << 30,
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        if req.sampling_params.max_new_tokens < 0:
-            req.origin_input_ids = req.origin_input_ids[
-                : self.max_total_num_tokens - 128
-            ]
-            logger.error("Request longer than memory pool size, truncated!!!")
-
         self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index b91179203..c34ebd32e 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -152,7 +152,7 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = None
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py
index 76b802886..28eb4537b 100644
--- a/python/sglang/srt/sampling_params.py
+++ b/python/sglang/srt/sampling_params.py
@@ -65,10 +65,11 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
-        if self.max_new_tokens < 0:
-            raise ValueError(
-                f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
-            )
+        if self.max_new_tokens is not None:
+            if self.max_new_tokens < 0:
+                raise ValueError(
+                    f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
+                )
 
     def normalize(self, tokenizer):
         # Process stop strings
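
Note: below is a minimal, hypothetical sketch of what the patch amounts to end to end. With max_tokens now defaulting to None in ChatCompletionRequest, the worker's min(max_new_tokens or 1 << 30, max_req_input_len - 1 - prompt_len) treats an unset budget as effectively unbounded and then caps it by the room left under the context length and KV cache pool. The helper name clamp_new_tokens and the numeric limits are invented for illustration; this is not code from the patch.

# Illustrative sketch only; mirrors the clamping introduced in tp_worker.py.
def clamp_new_tokens(requested_max_new_tokens, prompt_len,
                     context_len=4096, max_total_num_tokens=8192):
    # Mirrors: self.max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)
    max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)

    # Mirrors the truncation branch: overly long prompts are cut to the limit.
    prompt_len = min(prompt_len, max_req_input_len)

    # None (an unset max_tokens) falls through to the large sentinel, so the
    # generation budget is limited only by the remaining room.
    new_tokens = min(requested_max_new_tokens or 1 << 30,
                     max_req_input_len - 1 - prompt_len)
    return new_tokens, prompt_len


if __name__ == "__main__":
    print(clamp_new_tokens(None, prompt_len=1000))  # (3094, 1000): fills remaining room
    print(clamp_new_tokens(128, prompt_len=1000))   # (128, 1000): explicit cap is kept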