add profiling to bench_one_batch script (#2821)
This commit is contained in:
@@ -9,7 +9,8 @@ It accepts server arguments (the same as launch_server.py) and benchmark argumen
|
|||||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
||||||
## sweep through multiple data points and store (append) the results in a jsonl file:
|
## sweep through multiple data points and store (append) the results in a jsonl file:
|
||||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
||||||
|
## run with profiling:
|
||||||
|
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
|
||||||
# Usage (correctness test):
|
# Usage (correctness test):
|
||||||
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
||||||
|
|
||||||
@@ -77,6 +78,8 @@ class BenchArgs:
|
|||||||
correctness_test: bool = False
|
correctness_test: bool = False
|
||||||
# This is only used for correctness test
|
# This is only used for correctness test
|
||||||
cut_len: int = 4
|
cut_len: int = 4
|
||||||
|
profile: bool = False
|
||||||
|
profile_filename_prefix: str = "profile"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
@@ -95,6 +98,19 @@ class BenchArgs:
|
|||||||
)
|
)
|
||||||
parser.add_argument("--correctness-test", action="store_true")
|
parser.add_argument("--correctness-test", action="store_true")
|
||||||
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Torch Profiler. The endpoint must be launched with "
|
||||||
|
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile-filename-prefix",
|
||||||
|
type=str,
|
||||||
|
default=BenchArgs.profile_filename_prefix,
|
||||||
|
help="Prefix of the profiling file names. The full profiling result file(s) will be "
|
||||||
|
'"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
@@ -286,7 +302,16 @@ def synchronize(device):
|
|||||||
|
|
||||||
|
|
||||||
def latency_test_run_once(
|
def latency_test_run_once(
|
||||||
run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
|
run_name,
|
||||||
|
model_runner,
|
||||||
|
rank_print,
|
||||||
|
reqs,
|
||||||
|
batch_size,
|
||||||
|
input_len,
|
||||||
|
output_len,
|
||||||
|
device,
|
||||||
|
profile,
|
||||||
|
profile_filename_prefix,
|
||||||
):
|
):
|
||||||
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
||||||
if batch_size > max_batch_size:
|
if batch_size > max_batch_size:
|
||||||
@@ -308,6 +333,17 @@ def latency_test_run_once(
|
|||||||
|
|
||||||
tot_latency = 0
|
tot_latency = 0
|
||||||
|
|
||||||
|
profiler = None
|
||||||
|
if profile:
|
||||||
|
profiler = torch.profiler.profile(
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
with_stack=True,
|
||||||
|
)
|
||||||
|
profiler.start()
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
synchronize(device)
|
synchronize(device)
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
@@ -338,6 +374,13 @@ def latency_test_run_once(
|
|||||||
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if profile:
|
||||||
|
profiler.stop()
|
||||||
|
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
|
||||||
|
parent_dir = os.path.dirname(os.path.abspath(profile_filename))
|
||||||
|
os.makedirs(parent_dir, exist_ok=True)
|
||||||
|
profiler.export_chrome_trace(profile_filename)
|
||||||
|
|
||||||
# Record decode timing from 2nd output
|
# Record decode timing from 2nd output
|
||||||
if output_len > 1:
|
if output_len > 1:
|
||||||
med_decode_latency = np.median(decode_latencies)
|
med_decode_latency = np.median(decode_latencies)
|
||||||
@@ -386,6 +429,8 @@ def latency_test(
|
|||||||
bench_args.input_len[0],
|
bench_args.input_len[0],
|
||||||
8, # shorter decoding to speed up the warmup
|
8, # shorter decoding to speed up the warmup
|
||||||
server_args.device,
|
server_args.device,
|
||||||
|
profile=False,
|
||||||
|
profile_filename_prefix="", # not used
|
||||||
)
|
)
|
||||||
|
|
||||||
rank_print("Benchmark ...")
|
rank_print("Benchmark ...")
|
||||||
@@ -405,6 +450,8 @@ def latency_test(
|
|||||||
il,
|
il,
|
||||||
ol,
|
ol,
|
||||||
server_args.device,
|
server_args.device,
|
||||||
|
bench_args.profile,
|
||||||
|
bench_args.profile_filename_prefix,
|
||||||
)
|
)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
result_list.append(ret)
|
result_list.append(ret)
|
||||||
|
|||||||
Reference in New Issue
Block a user