Fix torch profiler bugs for bench_offline_throughput.py (#6557)

This commit is contained in:
Yueyang Pan
2025-06-09 20:33:41 +08:00
committed by GitHub
parent 451ffe74d9
commit 98c00a2df1
5 changed files with 49 additions and 5 deletions

View File

@@ -11,7 +11,9 @@ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1
"""
import argparse
import asyncio
import dataclasses
import inspect
import json
import logging
import os
@@ -235,8 +237,10 @@ def throughput_test_once(
latency = time.perf_counter() - st
if profile:
dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
known_files = set(os.listdir(dir))
backend.stop_profile()
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
monitor_trace_file(known_files, dir)
if backend_name == "runtime":
gen_out = json.loads(gen_out)
@@ -260,6 +264,10 @@ def throughput_test_once(
measurement_results["total_input_tokens"]
+ measurement_results["total_output_tokens"]
) / latency
if inspect.isawaitable(server_info):
server_info = asyncio.run(server_info)
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
"last_gen_throughput"
]
@@ -267,11 +275,9 @@ def throughput_test_once(
return measurement_results
def monitor_trace_file(directory, interval=1):
def monitor_trace_file(known_files, directory, interval=1):
print(f"Monitoring {directory} for new trace files...")
known_files = set(os.listdir(directory))
while True:
flag = False
time.sleep(interval)