Fix the case when max_new_tokens is too large (#1025)
This commit is contained in:
@@ -18,12 +18,16 @@ limitations under the License.
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.radix_cache import TreeNode
|
||||
|
||||
# Clip the max new tokens for the request whose max_new_tokens is very large.
|
||||
# This can prevent the server from being too conservative.
|
||||
CLIP_MAX_NEW_TOKENS = 4096
|
||||
|
||||
|
||||
class PolicyScheduler:
|
||||
def __init__(self, policy: str, tree_cache: BasePrefixCache):
|
||||
@@ -98,7 +102,7 @@ class PrefillAdder:
|
||||
tree_cache: BasePrefixCache,
|
||||
rem_total_tokens: int,
|
||||
rem_input_tokens: int,
|
||||
rem_chunk_tokens: int,
|
||||
rem_chunk_tokens: Optional[int],
|
||||
):
|
||||
self.tree_cache = tree_cache
|
||||
self.rem_total_tokens = rem_total_tokens
|
||||
@@ -126,7 +130,11 @@ class PrefillAdder:
|
||||
):
|
||||
self.rem_total_tokens -= sum(
|
||||
[
|
||||
(r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
|
||||
min(
|
||||
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
||||
CLIP_MAX_NEW_TOKENS,
|
||||
)
|
||||
* new_token_ratio
|
||||
for r in running_batch.reqs
|
||||
]
|
||||
)
|
||||
@@ -151,7 +159,11 @@ class PrefillAdder:
|
||||
self._prefill_one_req(
|
||||
len(req.prefix_indices),
|
||||
req.extend_input_len,
|
||||
req.sampling_params.max_new_tokens if not truncated else 0,
|
||||
(
|
||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
|
||||
if not truncated
|
||||
else 0
|
||||
),
|
||||
)
|
||||
|
||||
# Return if chunked prefill not finished
|
||||
@@ -168,7 +180,9 @@ class PrefillAdder:
|
||||
self.rem_total_tokens += delta
|
||||
|
||||
def add_one_req(self, req: Req):
|
||||
total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
|
||||
total_tokens = req.extend_input_len + min(
|
||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
|
||||
)
|
||||
input_tokens = req.extend_input_len
|
||||
prefix_len = len(req.prefix_indices)
|
||||
|
||||
@@ -191,7 +205,9 @@ class PrefillAdder:
|
||||
self.can_run_list.append(req)
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
self._prefill_one_req(
|
||||
prefix_len, input_tokens, req.sampling_params.max_new_tokens
|
||||
prefix_len,
|
||||
input_tokens,
|
||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
|
||||
)
|
||||
else:
|
||||
# Chunked prefill
|
||||
|
||||
Reference in New Issue
Block a user