diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 485242781..49727b121 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -30,8 +30,10 @@
 import argparse
 import dataclasses
 import logging
 import multiprocessing
+import os
 import time
+
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -70,6 +72,7 @@ class BenchArgs:
 
 def load_model(server_args, tp_rank):
     suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     model_config = ModelConfig(path=server_args.model_path)
     model_runner = ModelRunner(
@@ -81,7 +84,7 @@ def load_model(server_args, tp_rank):
         nccl_port=28888,
         server_args=server_args,
     )
-    print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
     tokenizer = get_tokenizer(
         server_args.tokenizer_path,
         tokenizer_mode=server_args.tokenizer_mode,
@@ -201,7 +204,7 @@ def correctness_test(
 
     # Print
     for i in range(len(reqs)):
-        print(tokenizer.decode(output_ids[i]))
+        rank_print(tokenizer.decode(output_ids[i]))
 
 
 def latency_test(
@@ -213,7 +216,7 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(
+    rank_print(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
@@ -299,6 +302,8 @@ def main(server_args, bench_args):
     for proc in workers:
         proc.join()
 
+    proc.terminate()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
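
The core of this patch is the rank_print idiom: under tensor parallelism every worker process executes the same benchmark code, so gating print on tp_rank == 0 keeps each shared message from being emitted once per rank. A minimal standalone sketch of the pattern, separate from the patch itself (the worker function and tp_size = 4 below are illustrative stand-ins, not names from this file):

import multiprocessing


def worker(tp_rank: int, tp_size: int) -> None:
    # Rank 0 keeps the real print; every other rank gets a no-op with
    # the same signature, so shared log lines appear exactly once.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    rank_print(f"starting benchmark with {tp_size} workers")  # printed once
    print(f"rank {tp_rank} alive")  # printed by every rank


if __name__ == "__main__":
    tp_size = 4
    workers = [
        multiprocessing.Process(target=worker, args=(rank, tp_size))
        for rank in range(tp_size)
    ]
    for proc in workers:
        proc.start()
    for proc in workers:
        proc.join()

Because the no-op branch accepts *args and **kwargs, call sites need no rank checks of their own, which is why the patch can swap print for rank_print without touching any arguments.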