fix: fix max_new_tokens uninitialized error (#9343)
This commit is contained in:
@@ -1181,6 +1181,16 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
self.send_to_tokenizer.send_pyobj(output)
|
self.send_to_tokenizer.send_pyobj(output)
|
||||||
|
|
||||||
|
def init_req_max_new_tokens(self, req):
|
||||||
|
req.sampling_params.max_new_tokens = min(
|
||||||
|
(
|
||||||
|
req.sampling_params.max_new_tokens
|
||||||
|
if req.sampling_params.max_new_tokens is not None
|
||||||
|
else 1 << 30
|
||||||
|
),
|
||||||
|
self.max_req_len - len(req.origin_input_ids) - 1,
|
||||||
|
)
|
||||||
|
|
||||||
def handle_generate_request(
|
def handle_generate_request(
|
||||||
self,
|
self,
|
||||||
recv_req: TokenizedGenerateReqInput,
|
recv_req: TokenizedGenerateReqInput,
|
||||||
@@ -1244,6 +1254,7 @@ class Scheduler(
|
|||||||
req.set_finish_with_abort(
|
req.set_finish_with_abort(
|
||||||
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
||||||
)
|
)
|
||||||
|
self.init_req_max_new_tokens(req)
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
@@ -1251,6 +1262,7 @@ class Scheduler(
|
|||||||
session = self.sessions[recv_req.session_params.id]
|
session = self.sessions[recv_req.session_params.id]
|
||||||
req = session.create_req(recv_req, self.tokenizer)
|
req = session.create_req(recv_req, self.tokenizer)
|
||||||
if isinstance(req.finished_reason, FINISH_ABORT):
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||||
|
self.init_req_max_new_tokens(req)
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1270,9 +1282,13 @@ class Scheduler(
|
|||||||
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.init_req_max_new_tokens(req)
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# initialize before returning
|
||||||
|
self.init_req_max_new_tokens(req)
|
||||||
|
|
||||||
# Validate prompt length
|
# Validate prompt length
|
||||||
error_msg = validate_input_length(
|
error_msg = validate_input_length(
|
||||||
req,
|
req,
|
||||||
@@ -1306,15 +1322,6 @@ class Scheduler(
|
|||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
req.sampling_params.max_new_tokens = min(
|
|
||||||
(
|
|
||||||
req.sampling_params.max_new_tokens
|
|
||||||
if req.sampling_params.max_new_tokens is not None
|
|
||||||
else 1 << 30
|
|
||||||
),
|
|
||||||
self.max_req_len - len(req.origin_input_ids) - 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Init grammar cache for this request
|
# Init grammar cache for this request
|
||||||
add_to_grammar_queue = False
|
add_to_grammar_queue = False
|
||||||
if (
|
if (
|
||||||
|
|||||||
Reference in New Issue
Block a user