Fix the default arguments of bench_offline_throughput.py & simplify detokenizer manager (#2042)
This commit is contained in:
@@ -1,20 +1,20 @@
|
|||||||
"""
|
"""
|
||||||
Benchmark the throughput of using the offline LLM engine.
|
Benchmark the throughput of using the offline LLM engine.
|
||||||
This script does not launch a server.
|
This script does not launch a server.
|
||||||
It accepts the same arguments as launch_server.py and additional benchmark arguments
|
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
## Sharegpt dataset with default args
|
## Sharegpt dataset with default args
|
||||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||||
|
|
||||||
## Random dataset with default args
|
## Random dataset with default args
|
||||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random
|
||||||
|
|
||||||
## Shared prefix dataset with default args
|
## Shared prefix dataset with default args
|
||||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name generated-shared-prefix
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix
|
||||||
|
|
||||||
## Sharegpt dataset on runtime backend
|
## Sharegpt dataset on runtime backend
|
||||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --backend runtime
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -23,7 +23,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -45,14 +45,15 @@ class BenchArgs:
|
|||||||
dataset_name: str = "sharegpt"
|
dataset_name: str = "sharegpt"
|
||||||
dataset_path: str = ""
|
dataset_path: str = ""
|
||||||
num_prompts: int = 1000
|
num_prompts: int = 1000
|
||||||
sharegpt_output_len: int = 256
|
sharegpt_output_len: Optional[int] = None
|
||||||
random_input_len: int = 256
|
random_input_len: int = 1024
|
||||||
random_output_len: int = 256
|
random_output_len: int = 1024
|
||||||
random_range_ratio: float = 0.0
|
random_range_ratio: float = 0.0
|
||||||
gen_num_groups: int = 8
|
gen_num_groups: int = 64
|
||||||
gen_prompts_per_group: int = 16
|
gen_prompts_per_group: int = 16
|
||||||
gen_system_prompt_len: int = 128
|
gen_system_prompt_len: int = 2048
|
||||||
gen_question_len: int = 256
|
gen_question_len: int = 128
|
||||||
|
gen_output_len: int = 256
|
||||||
disable_ignore_eos: bool = False
|
disable_ignore_eos: bool = False
|
||||||
seed: int = 1
|
seed: int = 1
|
||||||
|
|
||||||
@@ -129,6 +130,12 @@ class BenchArgs:
|
|||||||
default=BenchArgs.gen_question_len,
|
default=BenchArgs.gen_question_len,
|
||||||
help="Question length, used" "only for generate-shared-prefix",
|
help="Question length, used" "only for generate-shared-prefix",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gen-output-len",
|
||||||
|
type=int,
|
||||||
|
default=BenchArgs.gen_output_len,
|
||||||
|
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-ignore-eos",
|
"--disable-ignore-eos",
|
||||||
type=bool,
|
type=bool,
|
||||||
@@ -139,12 +146,8 @@ class BenchArgs:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
# use the default value's type to case the args into correct types.
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
print(attrs)
|
|
||||||
return cls(
|
|
||||||
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def throughput_test_once(
|
def throughput_test_once(
|
||||||
@@ -224,6 +227,7 @@ def throughput_test(
|
|||||||
random.seed(bench_args.seed)
|
random.seed(bench_args.seed)
|
||||||
np.random.seed(bench_args.seed)
|
np.random.seed(bench_args.seed)
|
||||||
|
|
||||||
|
# Read dataset
|
||||||
input_requests = get_dataset(bench_args, tokenizer)
|
input_requests = get_dataset(bench_args, tokenizer)
|
||||||
|
|
||||||
warmup_requests = sample_random_requests(
|
warmup_requests = sample_random_requests(
|
||||||
|
|||||||
@@ -1241,10 +1241,12 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--random-input-len",
|
"--random-input-len",
|
||||||
type=int,
|
type=int,
|
||||||
|
default=1024,
|
||||||
help="Number of input tokens per request, used only for random dataset.",
|
help="Number of input tokens per request, used only for random dataset.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--random-output-len",
|
"--random-output-len",
|
||||||
|
default=1024,
|
||||||
type=int,
|
type=int,
|
||||||
help="Number of output tokens per request, used only for random dataset.",
|
help="Number of output tokens per request, used only for random dataset.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -100,20 +100,6 @@ class DetokenizerManager:
|
|||||||
|
|
||||||
if isinstance(recv_obj, BatchEmbeddingOut):
|
if isinstance(recv_obj, BatchEmbeddingOut):
|
||||||
# If it is embedding model, no detokenization is needed.
|
# If it is embedding model, no detokenization is needed.
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
BatchEmbeddingOut(
|
|
||||||
rids=recv_obj.rids,
|
|
||||||
embeddings=recv_obj.embeddings,
|
|
||||||
meta_info=recv_obj.meta_info,
|
|
||||||
finished_reason=recv_obj.finished_reason,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
elif isinstance(recv_obj, UpdateWeightReqOutput):
|
|
||||||
# If it is a weight update request, no detokenization is needed.
|
|
||||||
self.send_to_tokenizer.send_pyobj(recv_obj)
|
|
||||||
continue
|
|
||||||
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
|
||||||
self.send_to_tokenizer.send_pyobj(recv_obj)
|
self.send_to_tokenizer.send_pyobj(recv_obj)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -114,6 +114,9 @@ class Scheduler:
|
|||||||
self.recv_from_tokenizer = get_zmq_socket(
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PULL, port_args.scheduler_input_ipc_name
|
context, zmq.PULL, port_args.scheduler_input_ipc_name
|
||||||
)
|
)
|
||||||
|
self.send_to_tokenizer = get_zmq_socket(
|
||||||
|
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
||||||
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
# Directly send to the tokenizer/api
|
# Directly send to the tokenizer/api
|
||||||
@@ -127,6 +130,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.recv_from_tokenizer = None
|
self.recv_from_tokenizer = None
|
||||||
|
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||||
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||||
|
|
||||||
# Init tokenizer
|
# Init tokenizer
|
||||||
@@ -421,7 +425,7 @@ class Scheduler:
|
|||||||
self.abort_request(recv_req)
|
self.abort_request(recv_req)
|
||||||
elif isinstance(recv_req, UpdateWeightReqInput):
|
elif isinstance(recv_req, UpdateWeightReqInput):
|
||||||
success, message = self.update_weights(recv_req)
|
success, message = self.update_weights(recv_req)
|
||||||
self.send_to_detokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
UpdateWeightReqOutput(success, message)
|
UpdateWeightReqOutput(success, message)
|
||||||
)
|
)
|
||||||
elif isinstance(recv_req, ProfileReq):
|
elif isinstance(recv_req, ProfileReq):
|
||||||
@@ -430,7 +434,7 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
self.stop_profile()
|
self.stop_profile()
|
||||||
elif isinstance(recv_req, GetMemPoolSizeReq):
|
elif isinstance(recv_req, GetMemPoolSizeReq):
|
||||||
self.send_to_detokenizer.send_pyobj(
|
self.send_to_tokenizer.send_pyobj(
|
||||||
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -11,16 +11,24 @@ from sglang.test.test_utils import run_mmlu_test
|
|||||||
|
|
||||||
class TestOverlapSchedule(unittest.TestCase):
|
class TestOverlapSchedule(unittest.TestCase):
|
||||||
def test_no_radix_attention_chunked_prefill(self):
|
def test_no_radix_attention_chunked_prefill(self):
|
||||||
run_mmlu_test(disable_radix_cache=True, chunked_prefill_size=32)
|
run_mmlu_test(
|
||||||
|
disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
|
||||||
|
)
|
||||||
|
|
||||||
def test_no_radix_attention_no_chunked_prefill(self):
|
def test_no_radix_attention_no_chunked_prefill(self):
|
||||||
run_mmlu_test(disable_radix_cache=True, chunked_prefill_size=-1)
|
run_mmlu_test(
|
||||||
|
disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
|
||||||
|
)
|
||||||
|
|
||||||
def test_radix_attention_chunked_prefill(self):
|
def test_radix_attention_chunked_prefill(self):
|
||||||
run_mmlu_test(disable_radix_cache=False, chunked_prefill_size=32)
|
run_mmlu_test(
|
||||||
|
disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
|
||||||
|
)
|
||||||
|
|
||||||
def test_radix_attention_no_chunked_prefill(self):
|
def test_radix_attention_no_chunked_prefill(self):
|
||||||
run_mmlu_test(disable_radix_cache=False, chunked_prefill_size=-1)
|
run_mmlu_test(
|
||||||
|
disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user