Fix max_tokens for OpenAI chat completion API (#766)
This commit is contained in:
@@ -98,17 +98,21 @@ class ModelTpServer:
|
||||
if server_args.max_prefill_tokens is None
|
||||
else server_args.max_prefill_tokens
|
||||
)
|
||||
self.max_running_requests = (
|
||||
self.max_total_num_tokens // 2
|
||||
if server_args.max_running_requests is None
|
||||
else server_args.max_running_requests
|
||||
)
|
||||
self.max_running_requests = min(
|
||||
self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
|
||||
(
|
||||
self.max_total_num_tokens // 2
|
||||
if server_args.max_running_requests is None
|
||||
else server_args.max_running_requests
|
||||
),
|
||||
self.model_runner.req_to_token_pool.size - 1,
|
||||
)
|
||||
self.int_token_logit_bias = torch.tensor(
|
||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||
)
|
||||
self.max_req_input_len = min(
|
||||
self.model_config.context_len - 1,
|
||||
self.max_total_num_tokens - 1,
|
||||
)
|
||||
set_random_seed(server_args.random_seed)
|
||||
|
||||
# Print info
|
||||
@@ -295,18 +299,16 @@ class ModelTpServer:
|
||||
)
|
||||
|
||||
# Truncate prompts that are too long
|
||||
req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
|
||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||
logger.warn(
|
||||
"Request length is longer than the KV cache pool size or "
|
||||
"the max context length. Truncated!!!"
|
||||
)
|
||||
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||
req.sampling_params.max_new_tokens = min(
|
||||
req.sampling_params.max_new_tokens,
|
||||
self.model_config.context_len - 1 - len(req.origin_input_ids),
|
||||
self.max_total_num_tokens - 128 - len(req.origin_input_ids),
|
||||
req.sampling_params.max_new_tokens or 1 << 30,
|
||||
self.max_req_input_len - 1 - len(req.origin_input_ids),
|
||||
)
|
||||
if req.sampling_params.max_new_tokens < 0:
|
||||
req.origin_input_ids = req.origin_input_ids[
|
||||
: self.max_total_num_tokens - 128
|
||||
]
|
||||
logger.error("Request longer than memory pool size, truncated!!!")
|
||||
|
||||
self.forward_queue.append(req)
|
||||
|
||||
def get_new_prefill_batch(self) -> Optional[Batch]:
|
||||
|
||||
Reference in New Issue
Block a user