Update run_batch interface and max_prefill_tokens (#574)
This commit is contained in:
@@ -1,13 +1,8 @@
|
|||||||
## SRT Unit Tests
|
## SRT Unit Tests
|
||||||
|
|
||||||
### Low-level API
|
### Latency Alignment
|
||||||
```
|
```
|
||||||
cd sglang/test/srt/model
|
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
|
||||||
|
|
||||||
python3 test_llama_low_api.py
|
|
||||||
python3 test_llama_extend.py
|
|
||||||
python3 test_llava_low_api.py
|
|
||||||
python3 bench_llama_low_api.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### High-level API
|
### High-level API
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ class SglFunction:
|
|||||||
argspec = inspect.getfullargspec(func)
|
argspec = inspect.getfullargspec(func)
|
||||||
assert argspec.args[0] == "s", 'The first argument must be "s"'
|
assert argspec.args[0] == "s", 'The first argument must be "s"'
|
||||||
self.arg_names = argspec.args[1:]
|
self.arg_names = argspec.args[1:]
|
||||||
|
self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
|
||||||
|
|
||||||
def bind(self, **kwargs):
|
def bind(self, **kwargs):
|
||||||
assert all(key in self.arg_names for key in kwargs)
|
assert all(key in self.arg_names for key in kwargs)
|
||||||
@@ -178,7 +179,18 @@ class SglFunction:
|
|||||||
assert isinstance(batch_kwargs, (list, tuple))
|
assert isinstance(batch_kwargs, (list, tuple))
|
||||||
if len(batch_kwargs) == 0:
|
if len(batch_kwargs) == 0:
|
||||||
return []
|
return []
|
||||||
assert isinstance(batch_kwargs[0], dict)
|
if not isinstance(batch_kwargs[0], dict):
|
||||||
|
num_programs = len(batch_kwargs)
|
||||||
|
# Convert each list of argument values into a dict of arg_name -> arg_value
|
||||||
|
batch_kwargs = [
|
||||||
|
{self.arg_names[i]: v for i, v in enumerate(arg_values)}
|
||||||
|
for arg_values in batch_kwargs
|
||||||
|
if isinstance(arg_values, (list, tuple)) and
|
||||||
|
len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
|
||||||
|
]
|
||||||
|
# Raise an exception if any argument list did not match the function signature
|
||||||
|
if len(batch_kwargs) != num_programs:
|
||||||
|
raise Exception("Given arguments mismatch the SGL function signature")
|
||||||
|
|
||||||
default_sampling_para = SglSamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
|
|||||||
@@ -98,10 +98,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
||||||
self.max_prefill_tokens = (
|
self.max_prefill_tokens = (
|
||||||
max(
|
4096
|
||||||
self.model_config.context_len,
|
|
||||||
min(self.max_total_num_tokens // 6, 32768),
|
|
||||||
)
|
|
||||||
if server_args.max_prefill_tokens is None
|
if server_args.max_prefill_tokens is None
|
||||||
else server_args.max_prefill_tokens
|
else server_args.max_prefill_tokens
|
||||||
)
|
)
|
||||||
@@ -371,8 +368,9 @@ class ModelTpServer:
|
|||||||
if (
|
if (
|
||||||
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
||||||
< available_size
|
< available_size
|
||||||
and req.extend_input_len + new_batch_input_tokens
|
and (req.extend_input_len + new_batch_input_tokens
|
||||||
< self.max_prefill_tokens
|
<= self.max_prefill_tokens
|
||||||
|
or len(can_run_list) == 0)
|
||||||
):
|
):
|
||||||
delta = self.tree_cache.inc_lock_ref(req.last_node)
|
delta = self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
available_size += delta
|
available_size += delta
|
||||||
|
|||||||
Reference in New Issue
Block a user