Update run_batch interface and max_prefill_tokens (#574)
This commit is contained in:
@@ -120,6 +120,7 @@ class SglFunction:
|
||||
argspec = inspect.getfullargspec(func)
|
||||
assert argspec.args[0] == "s", 'The first argument must be "s"'
|
||||
self.arg_names = argspec.args[1:]
|
||||
self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
|
||||
|
||||
def bind(self, **kwargs):
|
||||
assert all(key in self.arg_names for key in kwargs)
|
||||
@@ -178,7 +179,18 @@ class SglFunction:
|
||||
assert isinstance(batch_kwargs, (list, tuple))
|
||||
if len(batch_kwargs) == 0:
|
||||
return []
|
||||
assert isinstance(batch_kwargs[0], dict)
|
||||
if not isinstance(batch_kwargs[0], dict):
|
||||
num_programs = len(batch_kwargs)
|
||||
# change the list of argument values to dict of arg_name -> arg_value
|
||||
batch_kwargs = [
|
||||
{self.arg_names[i]: v for i, v in enumerate(arg_values)}
|
||||
for arg_values in batch_kwargs
|
||||
if isinstance(arg_values, (list, tuple)) and
|
||||
len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
|
||||
]
|
||||
# Ensure to raise an exception if the number of arguments mismatch
|
||||
if len(batch_kwargs) != num_programs:
|
||||
raise Exception("Given arguments mismatch the SGL function signature")
|
||||
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
|
||||
Reference in New Issue
Block a user