"""Benchmark the latency of processing a single batch of requests.""" import argparse import dataclasses import json import time from pathlib import Path from typing import List, Optional import math import os os.environ['CN_NOTIFIER_POOL_MAX'] = "1000" import numpy as np import torch from tqdm import tqdm from common import init_logger from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.utils import FlexibleArgumentParser from vllm_mlu._mlu_utils import USE_PAGED, VLLM_DUMP_MLU_INFO_EN from vllm_mlu.dump_info import LLMDumpInfo logger = init_logger(__name__) def main(args: argparse.Namespace): print(args) # Only support input case list assert len(args.input_case_list) > 0, "Only support input case list." new_case_list = [] max_model_len = 0 max_num_batched_tokens = 0 for case in args.input_case_list: case_info = case.split(",") assert len(case_info) == 3 batch_size, input_len, output_len = [int(v) for v in case_info] new_case_list.append((batch_size, input_len, output_len)) cur_max_model_len = input_len + output_len if cur_max_model_len > max_model_len: max_model_len = cur_max_model_len cur_max_num_batched_tokens = batch_size * input_len if cur_max_num_batched_tokens > max_num_batched_tokens: max_num_batched_tokens = cur_max_num_batched_tokens if max_num_batched_tokens < max_model_len: max_num_batched_tokens = max_model_len args.max_model_len = max_model_len args.max_num_batched_tokens = max_num_batched_tokens args.max_seq_len_to_capture = max_model_len if not USE_PAGED: args.block_size = max_model_len logger.warning(f"For unpaged mode, we must choose the max-scale to set block_size, " + f"which may decreases the concurrency of small-scale.") engine_args = EngineArgs.from_cli_args(args) # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
    llm = LLM(**dataclasses.asdict(engine_args),
              enable_context_mlugraph=True,
              context_batch_size_to_capture=new_case_list[0][0],
              context_seq_len_to_capture=new_case_list[0][1])

    if VLLM_DUMP_MLU_INFO_EN:
        LLM.dump_info.dev_info.should_stop = True

    for batch_size, input_len, output_len in new_case_list:
        print("\n" + "#" * 60 + "\n" +
              f"# Benchmark: batch_size={batch_size}, input_len={input_len}, "
              f"output_len={output_len} #\n" + "#" * 60 + "\n")

        # Re-start dump info.
        LLM.dump_info = LLMDumpInfo()
        LLM.dump_info.init_param(
            tensor_parallel_size=args.tensor_parallel_size,
            dtype=args.dtype,
            kv_cache_dtype=args.kv_cache_dtype,
            quantization=args.quantization,
            model=args.model,
            trust_remote_code=args.trust_remote_code)
        LLM.dump_info.memory_usage()

        # Reset metrics.
        llm.metric.reset_metric()

        # Re-capture the model for the context and decode MLU graphs.
        llm.llm_engine.model_executor.recapture_model(batch_size, input_len)

        # Check that the current case fits the engine configuration.
        num_gpu_block = llm.llm_engine.cache_config.num_gpu_blocks
        block_size = llm.llm_engine.cache_config.block_size
        max_num_batched_tokens = llm.llm_engine.scheduler_config.max_num_batched_tokens
        batched_input_tokens = input_len * batch_size
        batched_tokens_align = math.ceil(
            (input_len + output_len) / block_size) * block_size * batch_size
        if not args.enable_chunked_prefill:
            if max_num_batched_tokens < batched_input_tokens:
                logger.error(
                    f"batch({batch_size}) * input length({input_len}) ="
                    f" ({batched_input_tokens}) is larger than"
                    f" max_num_batched_tokens({max_num_batched_tokens})")
                logger.info(f"Try --max-num-batched-tokens {batched_input_tokens}")
                return
            elif num_gpu_block * block_size < batched_tokens_align:
                logger.error(
                    f"batch({batch_size}) * block-aligned (input length({input_len})"
                    f" + output length({output_len})) = ({batched_tokens_align}) tokens"
                    f" is larger than the maximum mlu blocks({num_gpu_block}) *"
                    f" block_size({block_size}) = ({num_gpu_block * block_size})"
                    f" tokens can hold.")
                if not USE_PAGED:
                    logger.info(
                        "Try reducing block_size so that the number of mlu blocks"
                        " exceeds the batch size, or increase -tp to get more"
                        " mlu blocks.")
                else:
                    logger.info("Try increasing -tp to get more mlu blocks.")
                return
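        # Illustrative numbers for the check above (not executed): with
        # batch_size=8, input_len=1024, output_len=512 and block_size=16,
        # batched_tokens_align = ceil(1536 / 16) * 16 * 8 = 12288 tokens,
        # which must fit within num_gpu_block * block_size.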
        # Warn if the sum of the input length and output length exceeds the
        # maximum model length, as only the first `max_model_len` tokens will
        # be processed.
        max_length = input_len + output_len
        max_model_len = llm.llm_engine.model_config.max_model_len
        if max_length > max_model_len:
            logger.warning(
                f"The sum of input length({input_len}) and output"
                f" length({output_len}) is larger than max model"
                f" length({max_model_len})")

        sampling_params = SamplingParams(
            n=args.n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=output_len,
        )
        print(sampling_params)
        dummy_prompt_token_ids = np.random.randint(10000,
                                                   size=(batch_size, input_len))
        dummy_prompts: List[PromptType] = [{
            "prompt_token_ids": batch
        } for batch in dummy_prompt_token_ids.tolist()]

        def run_to_completion(profile_dir: Optional[str] = None):
            if profile_dir:
                with torch.profiler.profile(
                        activities=[
                            torch.profiler.ProfilerActivity.CPU,
                            torch.profiler.ProfilerActivity.MLU,
                        ],
                        on_trace_ready=torch.profiler.tensorboard_trace_handler(
                            str(profile_dir))) as p:
                    llm.generate(dummy_prompts,
                                 sampling_params=sampling_params,
                                 use_tqdm=False)
                print(p.key_averages())
            else:
                start_time = time.perf_counter()
                llm.generate(dummy_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=False)
                end_time = time.perf_counter()
                latency = end_time - start_time
                return latency

        print("Warming up...")
        for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
            run_to_completion(profile_dir=None)

        if args.profile:
            profile_dir = args.profile_result_dir
            if not profile_dir:
                profile_dir = Path(
                    ".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
            print(f"Profiling (results will be saved to '{profile_dir}')...")
            run_to_completion(profile_dir=profile_dir)
            return

        # Benchmark.
        latencies = []
        for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
            latencies.append(run_to_completion(profile_dir=None))
        if args.show_per_iter:
            llm.get_metrics(args.num_iters_warmup, args.only_average, input_len,
                            output_len, args.tensor_parallel_size,
                            args.quantization, llm.dump_info,
                            show_per_iter=args.show_per_iter)
        latencies = np.array(latencies)
        percentages = [10, 25, 50, 75, 90, 99]
        percentiles = np.percentile(latencies, percentages)
        print(f'Avg latency: {np.mean(latencies)} seconds')
        for percentage, percentile in zip(percentages, percentiles):
            print(f'{percentage}% percentile latency: {percentile} seconds')

        # Output JSON results if specified.
        if args.output_json:
            results = {
                "avg_latency": np.mean(latencies),
                "latencies": latencies.tolist(),
                "percentiles": dict(zip(percentages, percentiles.tolist())),
            }
            with open(args.output_json, "w") as f:
                json.dump(results, f, indent=4)

        llm.get_metrics(args.num_iters_warmup, args.only_average, input_len,
                        output_len, args.tensor_parallel_size,
                        args.quantization, llm.dump_info)
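# Illustrative shape of the file written via --output-json below (values are
# made up; the integer percentile keys are serialized as strings by json.dump):
#   {
#       "avg_latency": 1.234,
#       "latencies": [1.20, 1.25, ...],
#       "percentiles": {"10": 1.21, "25": 1.22, "50": 1.23,
#                       "75": 1.24, "90": 1.25, "99": 1.26}
#   }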
if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument(
        '--input-case-list',
        nargs='+',
        default=['8,32,128'],
        help='List of cases, each formatted as "batch_size,input_len,output_len".')
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='Profile the generation process of a single batch.')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('Path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')
    parser.add_argument('--only-average',
                        action='store_true',
                        default=False,
                        help='Show only average metrics instead of per-iteration metrics.')
    parser.add_argument('--show-per-iter',
                        action='store_true',
                        help='If set, show metrics data per iteration.')
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)
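# Example invocation (illustrative only; the script name, model path, and case
# list are placeholders and depend on your environment):
#   python benchmark_latency.py --model /path/to/model \
#       --tensor-parallel-size 1 \
#       --input-case-list 1,128,128 8,1024,512 \
#       --num-iters 30 --num-iters-warmup 10 \
#       --output-json latency_results.json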