Fix the case when max_new_tokens is too large (#1025)

This commit is contained in:
Lianmin Zheng
2024-08-11 15:20:18 -07:00
committed by GitHub
parent 7b6a5332ca
commit d785412077

View File

@@ -18,12 +18,16 @@ limitations under the License.
import random
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
from typing import Dict, List, Optional
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode
# Clip the max new tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
CLIP_MAX_NEW_TOKENS = 4096
class PolicyScheduler:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -98,7 +102,7 @@ class PrefillAdder:
tree_cache: BasePrefixCache,
rem_total_tokens: int,
rem_input_tokens: int,
rem_chunk_tokens: int,
rem_chunk_tokens: Optional[int],
):
self.tree_cache = tree_cache
self.rem_total_tokens = rem_total_tokens
@@ -126,7 +130,11 @@ class PrefillAdder:
):
self.rem_total_tokens -= sum(
[
(r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* new_token_ratio
for r in running_batch.reqs
]
)
@@ -151,7 +159,11 @@ class PrefillAdder:
self._prefill_one_req(
len(req.prefix_indices),
req.extend_input_len,
req.sampling_params.max_new_tokens if not truncated else 0,
(
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
if not truncated
else 0
),
)
# Return if chunked prefill not finished
@@ -168,7 +180,9 @@ class PrefillAdder:
self.rem_total_tokens += delta
def add_one_req(self, req: Req):
total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
)
input_tokens = req.extend_input_len
prefix_len = len(req.prefix_indices)
@@ -191,7 +205,9 @@ class PrefillAdder:
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len, input_tokens, req.sampling_params.max_new_tokens
prefix_len,
input_tokens,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
# Chunked prefill