Fix max_tokens for OpenAI chat completion API (#766)
This commit is contained in:
@@ -98,17 +98,21 @@ class ModelTpServer:
|
|||||||
if server_args.max_prefill_tokens is None
|
if server_args.max_prefill_tokens is None
|
||||||
else server_args.max_prefill_tokens
|
else server_args.max_prefill_tokens
|
||||||
)
|
)
|
||||||
self.max_running_requests = (
|
self.max_running_requests = min(
|
||||||
|
(
|
||||||
self.max_total_num_tokens // 2
|
self.max_total_num_tokens // 2
|
||||||
if server_args.max_running_requests is None
|
if server_args.max_running_requests is None
|
||||||
else server_args.max_running_requests
|
else server_args.max_running_requests
|
||||||
)
|
),
|
||||||
self.max_running_requests = min(
|
self.model_runner.req_to_token_pool.size - 1,
|
||||||
self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
|
|
||||||
)
|
)
|
||||||
self.int_token_logit_bias = torch.tensor(
|
self.int_token_logit_bias = torch.tensor(
|
||||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||||
)
|
)
|
||||||
|
self.max_req_input_len = min(
|
||||||
|
self.model_config.context_len - 1,
|
||||||
|
self.max_total_num_tokens - 1,
|
||||||
|
)
|
||||||
set_random_seed(server_args.random_seed)
|
set_random_seed(server_args.random_seed)
|
||||||
|
|
||||||
# Print info
|
# Print info
|
||||||
@@ -295,18 +299,16 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Truncate prompts that are too long
|
# Truncate prompts that are too long
|
||||||
req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
|
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||||
req.sampling_params.max_new_tokens = min(
|
logger.warn(
|
||||||
req.sampling_params.max_new_tokens,
|
"Request length is longer than the KV cache pool size or "
|
||||||
self.model_config.context_len - 1 - len(req.origin_input_ids),
|
"the max context length. Truncated!!!"
|
||||||
self.max_total_num_tokens - 128 - len(req.origin_input_ids),
|
)
|
||||||
|
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||||
|
req.sampling_params.max_new_tokens = min(
|
||||||
|
req.sampling_params.max_new_tokens or 1 << 30,
|
||||||
|
self.max_req_input_len - 1 - len(req.origin_input_ids),
|
||||||
)
|
)
|
||||||
if req.sampling_params.max_new_tokens < 0:
|
|
||||||
req.origin_input_ids = req.origin_input_ids[
|
|
||||||
: self.max_total_num_tokens - 128
|
|
||||||
]
|
|
||||||
logger.error("Request longer than memory pool size, truncated!!!")
|
|
||||||
|
|
||||||
self.forward_queue.append(req)
|
self.forward_queue.append(req)
|
||||||
|
|
||||||
def get_new_prefill_batch(self) -> Optional[Batch]:
|
def get_new_prefill_batch(self) -> Optional[Batch]:
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
logit_bias: Optional[Dict[str, float]] = None
|
logit_bias: Optional[Dict[str, float]] = None
|
||||||
logprobs: Optional[bool] = False
|
logprobs: Optional[bool] = False
|
||||||
top_logprobs: Optional[int] = None
|
top_logprobs: Optional[int] = None
|
||||||
max_tokens: Optional[int] = 16
|
max_tokens: Optional[int] = None
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
presence_penalty: Optional[float] = 0.0
|
presence_penalty: Optional[float] = 0.0
|
||||||
response_format: Optional[ResponseFormat] = None
|
response_format: Optional[ResponseFormat] = None
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class SamplingParams:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
|
"presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
|
||||||
)
|
)
|
||||||
|
if self.max_new_tokens is not None:
|
||||||
if self.max_new_tokens < 0:
|
if self.max_new_tokens < 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
|
f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
|
||||||
|
|||||||
Reference in New Issue
Block a user