Fix max_tokens for OpenAI chat completion API (#766)
@@ -98,17 +98,21 @@ class ModelTpServer:
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
-        self.max_running_requests = (
-            self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None
-            else server_args.max_running_requests
-        )
         self.max_running_requests = min(
-            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+            (
+                self.max_total_num_tokens // 2
+                if server_args.max_running_requests is None
+                else server_args.max_running_requests
+            ),
+            self.model_runner.req_to_token_pool.size - 1,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
+        self.max_req_input_len = min(
+            self.model_config.context_len - 1,
+            self.max_total_num_tokens - 1,
+        )
         set_random_seed(server_args.random_seed)

         # Print info
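
Note: the hunk above folds the two-step clamping of max_running_requests into a single min(...) and introduces max_req_input_len, bounded by both the model context length and the KV cache pool size. A standalone sketch of that arithmetic, using made-up pool and context sizes (only the formulas follow the diff; the numbers are hypothetical):

# Illustrative only: standalone recomputation of the two bounds set up in the
# hunk above. The concrete numbers are hypothetical; only the formulas match.
max_total_num_tokens = 40960      # assumed KV cache pool size (tokens)
req_to_token_pool_size = 2048     # assumed req_to_token_pool capacity
context_len = 8192                # assumed model context length
max_running_requests_arg = None   # server_args.max_running_requests left unset

max_running_requests = min(
    (
        max_total_num_tokens // 2
        if max_running_requests_arg is None
        else max_running_requests_arg
    ),
    req_to_token_pool_size - 1,
)

# New in this commit: the longest prompt the server keeps after truncation.
max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)

print(max_running_requests)  # 2047
print(max_req_input_len)     # 8191
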
@@ -295,18 +299,16 @@ class ModelTpServer:
         )

         # Truncate prompts that are too long
-        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
         req.sampling_params.max_new_tokens = min(
-            req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.origin_input_ids),
-            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+            req.sampling_params.max_new_tokens or 1 << 30,
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        if req.sampling_params.max_new_tokens < 0:
-            req.origin_input_ids = req.origin_input_ids[
-                : self.max_total_num_tokens - 128
-            ]
-            logger.error("Request longer than memory pool size, truncated!!!")
-
         self.forward_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[Batch]:
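
Note: because max_tokens may now reach the scheduler as None (see the protocol change below), the clamp above substitutes `1 << 30` as an effectively unbounded sentinel via `or`, and the remaining budget is computed against the new max_req_input_len instead of the raw context and pool sizes. A rough sketch of that expression with hypothetical lengths:

# Illustrative only: mirrors the clamping expression in the hunk above.
max_req_input_len = 8191          # assumed, as derived in the first hunk
prompt_len = 1000                 # hypothetical len(req.origin_input_ids)
requested_max_new_tokens = None   # client omitted max_tokens

max_new_tokens = min(
    requested_max_new_tokens or 1 << 30,  # None -> effectively "no client limit"
    max_req_input_len - 1 - prompt_len,   # room left for this request
)
print(max_new_tokens)  # 7190
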
@@ -152,7 +152,7 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = None
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
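
Note: dropping the old default of 16 means a chat completion request that omits max_tokens is no longer silently capped at 16 generated tokens; the field reaches the server as None and the scheduler picks the real cap. A stripped-down stand-in model (not the real ChatCompletionRequest, just this one field) illustrates the change:

# Illustrative only: a minimal stand-in for the request model, keeping just
# the field whose default changed. Not the real ChatCompletionRequest class.
from typing import Optional

from pydantic import BaseModel


class ChatCompletionRequestSketch(BaseModel):
    max_tokens: Optional[int] = None  # was 16 before this commit


req = ChatCompletionRequestSketch()  # client sent no max_tokens
print(req.max_tokens)  # None -> the server-side clamp decides the real cap
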
@@ -65,10 +65,11 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
-        if self.max_new_tokens < 0:
-            raise ValueError(
-                f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
-            )
+        if self.max_new_tokens is not None:
+            if self.max_new_tokens < 0:
+                raise ValueError(
+                    f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
+                )

     def normalize(self, tokenizer):
         # Process stop strings
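
Note: the validator now has to tolerate the None that flows in from an omitted max_tokens and only range-checks explicit values. The same guard, pulled out of the class as a free function for illustration (the helper name is made up):

# Illustrative only: the None-tolerant check from the hunk above, lifted out
# as a free function.
from typing import Optional


def check_max_new_tokens(max_new_tokens: Optional[int]) -> None:
    # None means "let the server pick a cap later"; only explicit values are checked.
    if max_new_tokens is not None:
        if max_new_tokens < 0:
            raise ValueError(
                f"max_new_tokens must be at least 0, got {max_new_tokens}."
            )


check_max_new_tokens(None)  # accepted: server-side clamping fills it in later
check_max_new_tokens(256)   # accepted
# check_max_new_tokens(-1)  # would raise ValueError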