diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 03c575564..3e94ec811 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -43,6 +43,7 @@ I'm going to the park
 """
 
 import argparse
+import copy
 import dataclasses
 import itertools
 import json
@@ -84,12 +85,14 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    prompt_filename: str = ""
     result_filename: str = "result.jsonl"
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
     log_decode_step: int = 0
     profile: bool = False
+    profile_record_shapes: bool = False
     profile_filename_prefix: str = "profile"
 
     @staticmethod
@@ -104,6 +107,9 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument(
+            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
@@ -118,6 +124,11 @@ class BenchArgs:
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
+        parser.add_argument(
+            "--profile-record-shapes",
+            action="store_true",
+            help="Record tensor shapes in profiling results.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs_for_correctness_test(bench_args, tokenizer):
-    prompts = [
-        "The capital of France is",
-        "The capital of the United Kindom is",
-        "Today is a sunny day and I like",
-    ]
+def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
+    prompts = (
+        custom_prompts
+        if custom_prompts
+        else [
+            "The capital of France is",
+            "The capital of the United Kindom is",
+            "Today is a sunny day and I like",
+        ]
+    )
     input_ids = [tokenizer.encode(p) for p in prompts]
     sampling_params = SamplingParams(
         temperature=0,
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
     return reqs
 
 
-def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
-    input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(
+    batch_size, input_len, custom_inputs=None
+):
+    input_ids = (
+        custom_inputs
+        if custom_inputs
+        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+    )
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
     )
 
 
+def _read_prompts_from_file(prompt_file, rank_print):
+    """Read custom prompts from the file specified by `--prompt-filename`."""
+    if not prompt_file:
+        return []
+    if not os.path.exists(prompt_file):
+        rank_print(
+            f"Custom prompt file {prompt_file} not found. Using default inputs..."
+        )
+        return []
+    with open(prompt_file, "r") as pf:
+        return pf.readlines()
+
+
+def _save_profile_trace_results(profiler, filename):
+    parent_dir = os.path.dirname(os.path.abspath(filename))
+    os.makedirs(parent_dir, exist_ok=True)
+    profiler.export_chrome_trace(filename)
+    print(
+        profiler.key_averages(group_by_input_shape=True).table(
+            sort_by="self_cpu_time_total"
+        )
+    )
+
+
 def correctness_test(
     server_args,
     port_args,
@@ -298,7 +343,10 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    input_ids, reqs = prepare_inputs_for_correctness_test(
+        bench_args, tokenizer, custom_prompts
+    )
     rank_print(f"\n{input_ids=}\n")
 
     if bench_args.cut_len > 0:
@@ -344,6 +392,7 @@ def latency_test_run_once(
     device,
     log_decode_step,
     profile,
+    profile_record_shapes,
     profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +423,7 @@ def latency_test_run_once(
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
+            record_shapes=profile_record_shapes,
         )
         profiler.start()
 
@@ -391,10 +441,30 @@ def latency_test_run_once(
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput
 
+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(
+            f"torch profiler chrome trace for prefill saved to {profile_filename}"
+        )
+
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
+        if profile and i == output_len / 2:
+            profiler = None
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+
         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
@@ -407,13 +477,13 @@ def latency_test_run_once(
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
-    if profile:
-        profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
-        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
-        os.makedirs(parent_dir, exist_ok=True)
-        profiler.export_chrome_trace(profile_filename)
-        rank_print(f"torch profiler chrome trace saved to {profile_filename}")
+        if profile and i == output_len / 2:
+            profiler.stop()
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
+            rank_print(
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+            )
 
     # Record decode timing from 2nd output
     if output_len > 1:
@@ -469,17 +539,42 @@ def latency_test(
         server_args.device,
         log_decode_step=0,
         profile=False,
+        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
     )
 
     rank_print("Benchmark ...")
 
+    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
+    custom_input_len = len(custom_inputs)
+
     # Run the sweep
     result_list = []
     for bs, il, ol in itertools.product(
         bench_args.batch_size, bench_args.input_len, bench_args.output_len
     ):
-        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        bs_aligned_inputs = []
+        if custom_inputs:
+            if custom_input_len == bs:
+                bs_aligned_inputs = custom_inputs
+            elif custom_input_len > bs:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
+                    f"Using the first {bs} prompts."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
+            else:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
+                    f"Pad to the desired batch_size with the last prompt."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs)
+                bs_aligned_inputs.extend(
+                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
+                )
+
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
         ret = latency_test_run_once(
             bench_args.run_name,
             model_runner,
@@ -491,6 +586,7 @@ def latency_test(
             server_args.device,
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
+            bench_args.profile_record_shapes if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
         if ret is not None:
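
Note (not part of the patch): a minimal, self-contained sketch of the batch-size alignment that latency_test now applies to prompts read via --prompt-filename. The helper name align_inputs_to_batch is hypothetical and only mirrors the truncate/pad logic added above.

# Hypothetical standalone sketch of the alignment logic in latency_test above.
import copy
from typing import List


def align_inputs_to_batch(custom_inputs: List[List[int]], batch_size: int) -> List[List[int]]:
    """Truncate or pad tokenized prompts so exactly batch_size inputs are returned."""
    if not custom_inputs or len(custom_inputs) == batch_size:
        return custom_inputs
    if len(custom_inputs) > batch_size:
        # More prompts than batch slots: keep only the first batch_size prompts.
        return copy.deepcopy(custom_inputs[:batch_size])
    # Fewer prompts than batch slots: repeat the last prompt to fill the batch.
    aligned = copy.deepcopy(custom_inputs)
    aligned.extend([aligned[-1]] * (batch_size - len(custom_inputs)))
    return aligned


# Example: three tokenized prompts aligned to a batch of five.
print(align_inputs_to_batch([[1, 2], [3, 4], [5, 6]], 5))
# [[1, 2], [3, 4], [5, 6], [5, 6], [5, 6]]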