Update max_req_len and max_req_input_len (#1748)
This commit is contained in:
@@ -165,6 +165,7 @@ class Scheduler:
|
|||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
self.max_running_requests,
|
self.max_running_requests,
|
||||||
|
self.max_req_len,
|
||||||
self.max_req_input_len,
|
self.max_req_input_len,
|
||||||
self.random_seed,
|
self.random_seed,
|
||||||
self.device,
|
self.device,
|
||||||
@@ -421,13 +422,14 @@ class Scheduler:
|
|||||||
"the max context length. Truncated!!!"
|
"the max context length. Truncated!!!"
|
||||||
)
|
)
|
||||||
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||||
|
|
||||||
req.sampling_params.max_new_tokens = min(
|
req.sampling_params.max_new_tokens = min(
|
||||||
(
|
(
|
||||||
req.sampling_params.max_new_tokens
|
req.sampling_params.max_new_tokens
|
||||||
if req.sampling_params.max_new_tokens is not None
|
if req.sampling_params.max_new_tokens is not None
|
||||||
else 1 << 30
|
else 1 << 30
|
||||||
),
|
),
|
||||||
self.max_req_input_len - len(req.origin_input_ids),
|
self.max_req_len - len(req.origin_input_ids) - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|||||||
@@ -90,10 +90,14 @@ class TpModelWorker:
|
|||||||
),
|
),
|
||||||
self.model_runner.req_to_token_pool.size,
|
self.model_runner.req_to_token_pool.size,
|
||||||
)
|
)
|
||||||
self.max_req_input_len = min(
|
self.max_req_len = min(
|
||||||
self.model_config.context_len - 1,
|
self.model_config.context_len - 1,
|
||||||
self.max_total_num_tokens - 1,
|
self.max_total_num_tokens - 1,
|
||||||
)
|
)
|
||||||
|
self.max_req_input_len = self.max_req_len - 5
|
||||||
|
assert (
|
||||||
|
self.max_req_len > 0 and self.max_req_input_len > 0
|
||||||
|
), "Memory pool size is too small"
|
||||||
|
|
||||||
# Sync random seed across TP workers
|
# Sync random seed across TP workers
|
||||||
self.random_seed = broadcast_pyobj(
|
self.random_seed = broadcast_pyobj(
|
||||||
@@ -108,6 +112,7 @@ class TpModelWorker:
|
|||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
self.max_running_requests,
|
self.max_running_requests,
|
||||||
|
self.max_req_len,
|
||||||
self.max_req_input_len,
|
self.max_req_input_len,
|
||||||
self.random_seed,
|
self.random_seed,
|
||||||
self.device,
|
self.device,
|
||||||
|
|||||||
Reference in New Issue
Block a user