Fix the case when max_new_tokens is too large (#1025)
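
A request with a huge max_new_tokens (e.g. hundreds of thousands) previously
reserved its full decode budget up front, which made the scheduler far too
conservative when admitting new requests. This commit clips the value used
for scheduling to CLIP_MAX_NEW_TOKENS = 4096; the actual number of generated
tokens is still controlled by the request's own sampling parameters.

A minimal sketch of the budgeting rule, assuming a hypothetical helper
clip_new_tokens (not part of this diff):

    CLIP_MAX_NEW_TOKENS = 4096

    def clip_new_tokens(max_new_tokens: int, output_len: int = 0) -> int:
        # Reserve only the clipped remainder of the request's new tokens;
        # generation itself remains bounded by its sampling parameters.
        return min(max_new_tokens - output_len, CLIP_MAX_NEW_TOKENS)

    assert clip_new_tokens(1_000_000) == 4096  # huge request no longer pins the budget
    assert clip_new_tokens(128) == 128         # small requests are unaffected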
@@ -18,12 +18,16 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import TreeNode
 
+# Clip the max new tokens for the request whose max_new_tokens is very large.
+# This can prevent the server from being too conservative.
+CLIP_MAX_NEW_TOKENS = 4096
+
 
 class PolicyScheduler:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -98,7 +102,7 @@ class PrefillAdder:
         tree_cache: BasePrefixCache,
         rem_total_tokens: int,
         rem_input_tokens: int,
-        rem_chunk_tokens: int,
+        rem_chunk_tokens: Optional[int],
     ):
         self.tree_cache = tree_cache
         self.rem_total_tokens = rem_total_tokens
@@ -126,7 +130,11 @@ class PrefillAdder:
     ):
         self.rem_total_tokens -= sum(
             [
-                (r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
+                min(
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                    CLIP_MAX_NEW_TOKENS,
+                )
+                * new_token_ratio
                 for r in running_batch.reqs
             ]
         )
@@ -151,7 +159,11 @@ class PrefillAdder:
         self._prefill_one_req(
             len(req.prefix_indices),
             req.extend_input_len,
-            req.sampling_params.max_new_tokens if not truncated else 0,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
         )
 
         # Return if chunked prefill not finished
@@ -168,7 +180,9 @@ class PrefillAdder:
             self.rem_total_tokens += delta
 
     def add_one_req(self, req: Req):
-        total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
+        total_tokens = req.extend_input_len + min(
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+        )
         input_tokens = req.extend_input_len
         prefix_len = len(req.prefix_indices)
 
@@ -191,7 +205,9 @@ class PrefillAdder:
             self.can_run_list.append(req)
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(
-                prefix_len, input_tokens, req.sampling_params.max_new_tokens
+                prefix_len,
+                input_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
             # Chunked prefill
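
Note on the decode-side hunk (@@ -126): the reservation for the running batch
now clips each request's remaining new tokens before scaling by
new_token_ratio, so a single long-running request cannot dominate the token
budget. A rough worked example, with hypothetical values:

    max_new_tokens, output_len, new_token_ratio = 200_000, 1_000, 0.5
    reserved_before = (max_new_tokens - output_len) * new_token_ratio          # 99500.0
    reserved_after = min(max_new_tokens - output_len, 4096) * new_token_ratio  # 2048.0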