From 8f2c522abac932a9d4146000213dd559c5136c26 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 16 Jan 2025 06:24:31 -0800
Subject: [PATCH] Improve benchmark scripts and error message printing (#2922)

---
 python/sglang/bench_offline_throughput.py    | 37 +++++---
 python/sglang/bench_serving.py               | 95 ++++++++++---------
 python/sglang/srt/managers/io_struct.py      |  6 ++
 python/sglang/srt/managers/scheduler.py      |  3 +-
 .../sglang/srt/managers/tokenizer_manager.py | 41 ++++++--
 python/sglang/srt/server.py                  |  8 +-
 python/sglang/test/test_utils.py             |  1 +
 test/srt/test_moe_ep.py                      |  4 +-
 8 files changed, 125 insertions(+), 70 deletions(-)

diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index f32063b41..54b042c11 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -39,14 +39,15 @@ class BenchArgs:
     dataset_path: str = ""
     num_prompts: int = 1000
     sharegpt_output_len: Optional[int] = None
+    sharegpt_context_len: Optional[int] = None
     random_input_len: int = 1024
     random_output_len: int = 1024
     random_range_ratio: float = 0.0
-    gen_num_groups: int = 64
-    gen_prompts_per_group: int = 16
-    gen_system_prompt_len: int = 2048
-    gen_question_len: int = 128
-    gen_output_len: int = 256
+    gsp_num_groups: int = 64
+    gsp_prompts_per_group: int = 16
+    gsp_system_prompt_len: int = 2048
+    gsp_question_len: int = 128
+    gsp_output_len: int = 256
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
@@ -82,6 +83,12 @@ class BenchArgs:
             default=BenchArgs.sharegpt_output_len,
             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
         )
+        parser.add_argument(
+            "--sharegpt-context-len",
+            type=int,
+            default=BenchArgs.sharegpt_context_len,
+            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+        )
         parser.add_argument(
             "--random-input-len",
             type=int,
@@ -102,35 +109,35 @@ class BenchArgs:
             "used only for random dataset.",
         )
         parser.add_argument(
-            "--gen-num-groups",
+            "--gsp-num-groups",
             type=int,
-            default=BenchArgs.gen_num_groups,
+            default=BenchArgs.gsp_num_groups,
             help="Number of groups with shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-prompts-per-group",
+            "--gsp-prompts-per-group",
             type=int,
-            default=BenchArgs.gen_prompts_per_group,
+            default=BenchArgs.gsp_prompts_per_group,
             help="Number of prompts per group of shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-system-prompt-len",
+            "--gsp-system-prompt-len",
             type=int,
-            default=BenchArgs.gen_system_prompt_len,
+            default=BenchArgs.gsp_system_prompt_len,
             help="System prompt length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-question-len",
+            "--gsp-question-len",
             type=int,
-            default=BenchArgs.gen_question_len,
+            default=BenchArgs.gsp_question_len,
             help="Question length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-output-len",
+            "--gsp-output-len",
             type=int,
-            default=BenchArgs.gen_output_len,
+            default=BenchArgs.gsp_output_len,
             help="Target length in tokens for outputs in generated-shared-prefix dataset",
         )
         parser.add_argument(
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 941507705..991b4ddcf 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
+            context_len=args.sharegpt_context_len,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
+            num_groups=args.gsp_num_groups,
+            prompts_per_group=args.gsp_prompts_per_group,
+            system_prompt_len=args.gsp_system_prompt_len,
+            question_len=args.gsp_question_len,
+            output_len=args.gsp_output_len,
             tokenizer=tokenizer,
         )
     else:
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
+    context_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
         output_len = (
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 4 or output_len < 4:
+
+        if prompt_len < 1 or output_len < 1:
             # Prune too short sequences.
             continue
-        if prompt_len > 1024 or (
-            prompt_len + output_len > 2048 and fixed_output_len is None
-        ):
+
+        if context_len and prompt_len + output_len > context_len:
             # Prune too long sequences.
             continue
+
         filtered_dataset.append((prompt, prompt_len, output_len))
 
     print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
 
     # Create a unique cache filename based on the generation parameters
     cache_key = (
-        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
-        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
+        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
+        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
         f"{tokenizer.__class__.__name__}.pkl"
     )
     return cache_dir / cache_key
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
         default=None,
         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
     )
+    parser.add_argument(
+        "--sharegpt-context-len",
+        type=int,
+        default=None,
+        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+    )
     parser.add_argument(
         "--random-input-len",
         type=int,
@@ -1453,38 +1462,6 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
-
-    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
-    group.add_argument(
-        "--gen-num-groups",
-        type=int,
-        default=64,
-        help="Number of system prompt groups for generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-prompts-per-group",
-        type=int,
-        default=16,
-        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-system-prompt-len",
-        type=int,
-        default=2048,
-        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-question-len",
-        type=int,
-        default=128,
-        help="Target length in tokens for questions in generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-output-len",
-        type=int,
-        default=256,
-        help="Target length in tokens for outputs in generated-shared-prefix dataset",
-    )
     parser.add_argument(
         "--profile",
         action="store_true",
@@ -1497,5 +1474,37 @@ if __name__ == "__main__":
         default=None,
         help="The name of LoRA adapter",
     )
+
+    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
+    group.add_argument(
+        "--gsp-num-groups",
+        type=int,
+        default=64,
+        help="Number of system prompt groups for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-prompts-per-group",
+        type=int,
+        default=16,
+        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-system-prompt-len",
+        type=int,
+        default=2048,
+        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-question-len",
+        type=int,
+        default=128,
+        help="Target length in tokens for questions in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-output-len",
+        type=int,
+        default=256,
+        help="Target length in tokens for outputs in generated-shared-prefix dataset",
+    )
     args = parser.parse_args()
     run_benchmark(args)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 1698dfbeb..7f0705513 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -196,6 +199,7 @@ class GenerateReqInput:
                 top_logprobs_num=self.top_logprobs_num[i],
                 return_text_in_logprobs=self.return_text_in_logprobs,
                 stream=self.stream,
+                log_metrics=self.log_metrics,
                 modalities=self.modalities[i] if self.modalities else None,
                 lora_path=self.lora_path[i] if self.lora_path is not None else None,
             )
@@ -243,6 +247,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 169c202d3..6ee93b3cd 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -631,7 +631,8 @@ class Scheduler:
 
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 4f4e4f7dc..4e120f3a9 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -640,7 +641,9 @@ class TokenizerManager:
 
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )
 
         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
@@ -653,7 +656,9 @@ class TokenizerManager:
                 "not in the main thread. This disables graceful shutdown of the "
                 "tokenizer manager when SIGTERM is received."
             )
-        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
@@ -738,9 +743,13 @@ class TokenizerManager:
                     state.finished = recv_obj.finished_reasons[i] is not None
                     state.event.set()
 
-                    if self.enable_metrics:
+                    if self.enable_metrics and state.obj.log_metrics:
                         self.collect_metrics(state, recv_obj, i)
-                    if self.dump_requests_folder and state.finished:
+                    if (
+                        self.dump_requests_folder
+                        and state.finished
+                        and state.obj.log_metrics
+                    ):
                         self.dump_requests(state, out_dict)
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
@@ -887,20 +896,38 @@ class TokenizerManager:
         )
 
         if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
             to_dump = self.dump_request_list
             self.dump_request_list = []
 
             def background_task():
                 os.makedirs(self.dump_requests_folder, exist_ok=True)
-                current_time = datetime.now()
-                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
-                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                with open(filename, "wb") as f:
                     pickle.dump(to_dump, f)
 
             # Schedule the task to run in the background without awaiting it
             asyncio.create_task(asyncio.to_thread(background_task))
 
 
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print its exception.
+    We wrap it here so the exception is printed before the process exits.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
 class SignalHandler:
     def __init__(self, tokenizer_manager):
         self.tokenizer_manager = tokenizer_manager
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 93fe1304c..6b180039e 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -135,9 +135,13 @@ async def health_generate(request: Request) -> Response:
     sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
 
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
     else:
-        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
 
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 4121deb17..42e0b6d80 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -560,6 +560,7 @@ def run_bench_serving(
         tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
+        sharegpt_context_len=None,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
         random_range_ratio=0.0,
diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py
index 4d9fd435e..9f87eb24d 100644
--- a/test/srt/test_moe_ep.py
+++ b/test/srt/test_moe_ep.py
@@ -44,7 +44,7 @@ class TestEpMoE(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.5
+        self.assertGreater(metrics["score"], 0.5)
 
     def test_mgsm_en(self):
         args = SimpleNamespace(
@@ -56,7 +56,7 @@ class TestEpMoE(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.8
+        self.assertGreater(metrics["score"], 0.8)
 
 
 class TestEpMoEFP8(unittest.TestCase):
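
For illustration, a minimal, self-contained sketch of the exception-surfacing
pattern this patch adds around handle_loop() and sigterm_watchdog(): a
coroutine scheduled with loop.create_task can fail silently until the task
object is garbage collected, so the patch wraps each long-lived loop and logs
the traceback eagerly. The sketch simplifies the patch's behavior: flaky_loop
and main are hypothetical stand-ins, traceback.format_exc replaces sglang's
get_exception_traceback, and the error is re-raised instead of killing the
process tree.

import asyncio
import logging
import traceback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def print_exception_wrapper(func):
    # Await the wrapped coroutine function and log the full traceback on
    # failure, so the error is visible even if the task is never awaited.
    try:
        await func()
    except Exception:
        logger.error("Task hit an exception: %s", traceback.format_exc())
        raise


async def flaky_loop():
    # Hypothetical stand-in for a long-lived loop such as handle_loop().
    raise RuntimeError("boom")


async def main():
    task = asyncio.create_task(print_exception_wrapper(flaky_loop))
    # return_exceptions=True keeps the demo from crashing after the log line.
    await asyncio.gather(task, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(main())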