diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index 3c57e1144..924bad9f3 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -1,20 +1,20 @@
 """
 Benchmark the throughput of using the offline LLM engine.
 This script does not launch a server.
-It accepts the same arguments as launch_server.py and additional benchmark arguments
+It accepts server arguments (the same as launch_server.py) and benchmark arguments
 (the same as bench_serving.py).

 # Usage
 ## Sharegpt dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct

 ## Random dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random

 ## Shared prefix dataset with default args
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name generated-shared-prefix
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix

 ## Sharegpt dataset on runtime backend
-python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --backend runtime
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
 """

 import argparse
@@ -23,7 +23,7 @@ import json
 import logging
 import random
 import time
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import numpy as np

@@ -45,14 +45,15 @@ class BenchArgs:
     dataset_name: str = "sharegpt"
     dataset_path: str = ""
     num_prompts: int = 1000
-    sharegpt_output_len: int = 256
-    random_input_len: int = 256
-    random_output_len: int = 256
+    sharegpt_output_len: Optional[int] = None
+    random_input_len: int = 1024
+    random_output_len: int = 1024
     random_range_ratio: float = 0.0
-    gen_num_groups: int = 8
+    gen_num_groups: int = 64
     gen_prompts_per_group: int = 16
-    gen_system_prompt_len: int = 128
-    gen_question_len: int = 256
+    gen_system_prompt_len: int = 2048
+    gen_question_len: int = 128
+    gen_output_len: int = 256
     disable_ignore_eos: bool = False
     seed: int = 1

@@ -129,6 +130,12 @@ class BenchArgs:
             default=BenchArgs.gen_question_len,
             help="Question length, used" "only for generate-shared-prefix",
         )
+        parser.add_argument(
+            "--gen-output-len",
+            type=int,
+            default=BenchArgs.gen_output_len,
+            help="Target length in tokens for outputs in generated-shared-prefix dataset",
+        )
         parser.add_argument(
             "--disable-ignore-eos",
             type=bool,
@@ -139,12 +146,8 @@ class BenchArgs:

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        # use the default value's type to case the args into correct types.
-        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
-        print(attrs)
-        return cls(
-            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
-        )
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})


 def throughput_test_once(
@@ -224,6 +227,7 @@ def throughput_test(
     random.seed(bench_args.seed)
     np.random.seed(bench_args.seed)

+    # Read dataset
     input_requests = get_dataset(bench_args, tokenizer)

     warmup_requests = sample_random_requests(
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 68c672413..de69679a7 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -1241,10 +1241,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
+        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
+        default=1024,
         type=int,
         help="Number of output tokens per request, used only for random dataset.",
     )
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 0387124df..5db8ce4f1 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -100,20 +100,6 @@ class DetokenizerManager:

             if isinstance(recv_obj, BatchEmbeddingOut):
                 # If it is embedding model, no detokenization is needed.
-                self.send_to_tokenizer.send_pyobj(
-                    BatchEmbeddingOut(
-                        rids=recv_obj.rids,
-                        embeddings=recv_obj.embeddings,
-                        meta_info=recv_obj.meta_info,
-                        finished_reason=recv_obj.finished_reason,
-                    )
-                )
-                continue
-            elif isinstance(recv_obj, UpdateWeightReqOutput):
-                # If it is a weight update request, no detokenization is needed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
             else:
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 176a1f2f5..7232fc2a7 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -114,6 +114,9 @@ class Scheduler:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
+            self.send_to_tokenizer = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_ipc_name
+            )

             if server_args.skip_tokenizer_init:
                 # Directly send to the tokenizer/api
@@ -127,6 +130,7 @@ class Scheduler:
                 )
         else:
             self.recv_from_tokenizer = None
+            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

         # Init tokenizer
@@ -421,7 +425,7 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                     UpdateWeightReqOutput(success, message)
                 )
             elif isinstance(recv_req, ProfileReq):
@@ -430,7 +434,7 @@
                 else:
                     self.stop_profile()
             elif isinstance(recv_req, GetMemPoolSizeReq):
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                     GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                 )
             else:
diff --git a/test/srt/test_overlap_schedule.py b/test/srt/test_overlap_schedule.py
index c3d4a570d..367d2acc8 100644
--- a/test/srt/test_overlap_schedule.py
+++ b/test/srt/test_overlap_schedule.py
@@ -11,16 +11,24 @@ from sglang.test.test_utils import run_mmlu_test

 class TestOverlapSchedule(unittest.TestCase):
     def test_no_radix_attention_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=True, chunked_prefill_size=32)
+        run_mmlu_test(
+            disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
+        )

     def test_no_radix_attention_no_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=True, chunked_prefill_size=-1)
+        run_mmlu_test(
+            disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
+        )

     def test_radix_attention_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=False, chunked_prefill_size=32)
+        run_mmlu_test(
+            disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
+        )

     def test_radix_attention_no_chunked_prefill(self):
-        run_mmlu_test(disable_radix_cache=False, chunked_prefill_size=-1)
+        run_mmlu_test(
+            disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
+        )


 if __name__ == "__main__":
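
Note on the `from_cli_args` simplification in the first file: the removed version re-cast every CLI value through `type(attr.default)`, which cannot work once a field such as `sharegpt_output_len` defaults to `None` (the "cast" becomes `type(None)(value)`, which raises a `TypeError`); it also leaves a debug `print` behind. Since argparse already applies each option's `type=` converter, the values can be passed through unchanged. Below is a minimal, self-contained sketch of the same pass-through pattern; `DemoArgs` and its fields are illustrative stand-ins, not part of this patch.

```python
import argparse
import dataclasses
from typing import Optional


@dataclasses.dataclass
class DemoArgs:
    # Illustrative stand-in for BenchArgs (hypothetical names).
    num_prompts: int = 1000
    output_len: Optional[int] = None  # type(None)(v) would raise under the old cast

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "DemoArgs":
        # argparse has already applied each option's type= converter,
        # so values are passed through without re-casting.
        attrs = [field.name for field in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--num-prompts", type=int, default=DemoArgs.num_prompts)
parser.add_argument("--output-len", type=int, default=DemoArgs.output_len)

print(DemoArgs.from_cli_args(parser.parse_args(["--num-prompts", "8"])))
# -> DemoArgs(num_prompts=8, output_len=None)
```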