Fix the case when max_new_tokens is too large (#1025)

This commit is contained in:
Lianmin Zheng
2024-08-11 15:20:18 -07:00
committed by GitHub
parent 7b6a5332ca
commit d785412077

View File

@@ -18,12 +18,16 @@ limitations under the License.
import random
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List, Optional

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode
# Clip the max new tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
CLIP_MAX_NEW_TOKENS = 4096
class PolicyScheduler:
    def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -98,7 +102,7 @@ class PrefillAdder:
        tree_cache: BasePrefixCache,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
    ):
        self.tree_cache = tree_cache
        self.rem_total_tokens = rem_total_tokens
@@ -126,7 +130,11 @@ class PrefillAdder:
    ):
        self.rem_total_tokens -= sum(
            [
                min(
                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
                    CLIP_MAX_NEW_TOKENS,
                )
                * new_token_ratio
                for r in running_batch.reqs
            ]
        )
@@ -151,7 +159,11 @@ class PrefillAdder:
        self._prefill_one_req(
            len(req.prefix_indices),
            req.extend_input_len,
            (
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                if not truncated
                else 0
            ),
        )
        # Return if chunked prefill not finished
@@ -168,7 +180,9 @@ class PrefillAdder:
        self.rem_total_tokens += delta

    def add_one_req(self, req: Req):
        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)
@@ -191,7 +205,9 @@ class PrefillAdder:
            self.can_run_list.append(req)
            self.tree_cache.inc_lock_ref(req.last_node)
            self._prefill_one_req(
                prefix_len,
                input_tokens,
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
            )
        else:
            # Chunked prefill