Fix torch profiler bugs for bench_offline_throughput.py (#6557)
This commit is contained in:
@@ -11,7 +11,9 @@ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -235,8 +237,10 @@ def throughput_test_once(
|
||||
latency = time.perf_counter() - st
|
||||
|
||||
if profile:
|
||||
dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
|
||||
known_files = set(os.listdir(dir))
|
||||
backend.stop_profile()
|
||||
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
|
||||
monitor_trace_file(known_files, dir)
|
||||
|
||||
if backend_name == "runtime":
|
||||
gen_out = json.loads(gen_out)
|
||||
@@ -260,6 +264,10 @@ def throughput_test_once(
|
||||
measurement_results["total_input_tokens"]
|
||||
+ measurement_results["total_output_tokens"]
|
||||
) / latency
|
||||
|
||||
if inspect.isawaitable(server_info):
|
||||
server_info = asyncio.run(server_info)
|
||||
|
||||
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
|
||||
"last_gen_throughput"
|
||||
]
|
||||
@@ -267,11 +275,9 @@ def throughput_test_once(
|
||||
return measurement_results
|
||||
|
||||
|
||||
def monitor_trace_file(directory, interval=1):
|
||||
def monitor_trace_file(known_files, directory, interval=1):
|
||||
print(f"Monitoring {directory} for new trace files...")
|
||||
|
||||
known_files = set(os.listdir(directory))
|
||||
|
||||
while True:
|
||||
flag = False
|
||||
time.sleep(interval)
|
||||
|
||||
@@ -85,6 +85,22 @@ class RuntimeEndpoint(BaseBackend):
|
||||
)
|
||||
self._assert_success(res)
|
||||
|
||||
def start_profile(self):
|
||||
res = http_request(
|
||||
self.base_url + "/start_profile",
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
self._assert_success(res)
|
||||
|
||||
def stop_profile(self):
|
||||
res = http_request(
|
||||
self.base_url + "/stop_profile",
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
self._assert_success(res)
|
||||
|
||||
def commit_lazy_operations(self, s: StreamExecutor):
|
||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||
self._add_images(s, data)
|
||||
@@ -374,7 +390,8 @@ class Runtime:
|
||||
self.pid = None
|
||||
pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
proc = multiprocessing.Process(
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
proc = ctx.Process(
|
||||
target=launch_server,
|
||||
args=(self.server_args, pipe_writer),
|
||||
)
|
||||
@@ -406,6 +423,12 @@ class Runtime:
|
||||
kill_process_tree(self.pid)
|
||||
self.pid = None
|
||||
|
||||
def start_profile(self):
|
||||
self.endpoint.start_profile()
|
||||
|
||||
def stop_profile(self):
|
||||
self.endpoint.stop_profile()
|
||||
|
||||
def cache_prefix(self, prefix: str):
|
||||
self.endpoint.cache_prefix(prefix)
|
||||
|
||||
|
||||
@@ -116,6 +116,7 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
dataclass_to_string_truncated,
|
||||
get_bool_env_var,
|
||||
get_zmq_socket,
|
||||
kill_process_tree,
|
||||
)
|
||||
@@ -805,6 +806,8 @@ class TokenizerManager:
|
||||
profile_by_stage: bool = False,
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
|
||||
with_stack = False if with_stack is False or env_with_stack is False else True
|
||||
req = ProfileReq(
|
||||
type=ProfileReqType.START_PROFILE,
|
||||
output_dir=output_dir,
|
||||
|
||||
Reference in New Issue
Block a user