Update run_batch interface and max_prefill_tokens (#574)

This commit is contained in:
Ying Sheng
2024-06-30 18:26:04 -07:00
committed by GitHub
parent 11616fc6bd
commit 75b31a2a88
3 changed files with 19 additions and 14 deletions

View File

@@ -120,6 +120,7 @@ class SglFunction:
argspec = inspect.getfullargspec(func)
assert argspec.args[0] == "s", 'The first argument must be "s"'
self.arg_names = argspec.args[1:]
self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
def bind(self, **kwargs):
assert all(key in self.arg_names for key in kwargs)
@@ -178,7 +179,18 @@ class SglFunction:
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
return []
assert isinstance(batch_kwargs[0], dict)
if not isinstance(batch_kwargs[0], dict):
num_programs = len(batch_kwargs)
# change the list of argument values to dict of arg_name -> arg_value
batch_kwargs = [
{self.arg_names[i]: v for i, v in enumerate(arg_values)}
for arg_values in batch_kwargs
if isinstance(arg_values, (list, tuple)) and
len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
]
# Ensure to raise an exception if the number of arguments mismatch
if len(batch_kwargs) != num_programs:
raise Exception("Given arguments mismatch the SGL function signature")
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,

View File

@@ -98,10 +98,7 @@ class ModelTpServer:
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = (
max(
self.model_config.context_len,
min(self.max_total_num_tokens // 6, 32768),
)
4096
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
)
@@ -371,8 +368,9 @@ class ModelTpServer:
if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and req.extend_input_len + new_batch_input_tokens
< self.max_prefill_tokens
and (req.extend_input_len + new_batch_input_tokens
<= self.max_prefill_tokens
or len(can_run_list) == 0)
):
delta = self.tree_cache.inc_lock_ref(req.last_node)
available_size += delta