add profiling to bench_one_batch script (#2821)
This commit is contained in:
@@ -9,7 +9,8 @@ It accepts server arguments (the same as launch_server.py) and benchmark argumen
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
||||
## sweep through multiple data points and store (append) the results in a jsonl file:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
||||
|
||||
## run with profiling:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
|
||||
# Usage (correctness test):
|
||||
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
||||
|
||||
@@ -77,6 +78,8 @@ class BenchArgs:
|
||||
correctness_test: bool = False
|
||||
# This is only used for correctness test
|
||||
cut_len: int = 4
|
||||
profile: bool = False
|
||||
profile_filename_prefix: str = "profile"
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -95,6 +98,19 @@ class BenchArgs:
|
||||
)
|
||||
parser.add_argument("--correctness-test", action="store_true")
|
||||
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Use Torch Profiler. The endpoint must be launched with "
|
||||
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-filename-prefix",
|
||||
type=str,
|
||||
default=BenchArgs.profile_filename_prefix,
|
||||
help="Prefix of the profiling file names. The full profiling result file(s) be "
|
||||
'"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -286,7 +302,16 @@ def synchronize(device):
|
||||
|
||||
|
||||
def latency_test_run_once(
|
||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
|
||||
run_name,
|
||||
model_runner,
|
||||
rank_print,
|
||||
reqs,
|
||||
batch_size,
|
||||
input_len,
|
||||
output_len,
|
||||
device,
|
||||
profile,
|
||||
profile_filename_prefix,
|
||||
):
|
||||
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
||||
if batch_size > max_batch_size:
|
||||
@@ -308,6 +333,17 @@ def latency_test_run_once(
|
||||
|
||||
tot_latency = 0
|
||||
|
||||
profiler = None
|
||||
if profile:
|
||||
profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
)
|
||||
profiler.start()
|
||||
|
||||
# Prefill
|
||||
synchronize(device)
|
||||
tic = time.time()
|
||||
@@ -338,6 +374,13 @@ def latency_test_run_once(
|
||||
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
|
||||
if profile:
|
||||
profiler.stop()
|
||||
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
|
||||
parent_dir = os.path.dirname(os.path.abspath(profile_filename))
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
profiler.export_chrome_trace(profile_filename)
|
||||
|
||||
# Record decode timing from 2nd output
|
||||
if output_len > 1:
|
||||
med_decode_latency = np.median(decode_latencies)
|
||||
@@ -386,6 +429,8 @@ def latency_test(
|
||||
bench_args.input_len[0],
|
||||
8, # shorter decoding to speed up the warmup
|
||||
server_args.device,
|
||||
profile=False,
|
||||
profile_filename_prefix="", # not used
|
||||
)
|
||||
|
||||
rank_print("Benchmark ...")
|
||||
@@ -405,6 +450,8 @@ def latency_test(
|
||||
il,
|
||||
ol,
|
||||
server_args.device,
|
||||
bench_args.profile,
|
||||
bench_args.profile_filename_prefix,
|
||||
)
|
||||
if ret is not None:
|
||||
result_list.append(ret)
|
||||
|
||||
Reference in New Issue
Block a user