Fix torch profiler bugs for bench_offline_throughput.py (#6557)
This commit is contained in:
@@ -52,6 +52,17 @@
|
|||||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- Possible PyTorch Bug
|
||||||
|
If in any cases you encounter the following error (for example, using qwen 2.5 VL):
|
||||||
|
```bash
|
||||||
|
RuntimeError: !stack.empty() INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/profiler_python.cpp":983, please report a bug to PyTorch. Python replay stack is empty.
|
||||||
|
```
|
||||||
|
This is likely a PyTorch Bug reported in [Bug: vLLM Profiler](https://github.com/vllm-project/vllm/issues/18240) and [Bug: torch.profiler.profile](https://github.com/pytorch/pytorch/issues/101632). As a workaround, you may disable `with_stack` with an environment variable such as follows:
|
||||||
|
```bash
|
||||||
|
export SGLANG_PROFILE_WITH_STACK=False
|
||||||
|
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
|
||||||
|
```
|
||||||
|
|
||||||
- View Traces
|
- View Traces
|
||||||
|
|
||||||
Trace files can be loaded and visualized from:
|
Trace files can be loaded and visualized from:
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ SGLang supports various environment variables that can be used to configure its
|
|||||||
| Environment Variable | Description | Default Value |
|
| Environment Variable | Description | Default Value |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |
|
| `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |
|
||||||
|
| `SGLANG_PROFILE_WITH_STACK` | Set `with_stack` option (bool) for PyTorch profiler (capture stack trace) | `true` |
|
||||||
|
|
||||||
## Storage & Caching
|
## Storage & Caching
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -235,8 +237,10 @@ def throughput_test_once(
|
|||||||
latency = time.perf_counter() - st
|
latency = time.perf_counter() - st
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
|
dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
|
||||||
|
known_files = set(os.listdir(dir))
|
||||||
backend.stop_profile()
|
backend.stop_profile()
|
||||||
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
|
monitor_trace_file(known_files, dir)
|
||||||
|
|
||||||
if backend_name == "runtime":
|
if backend_name == "runtime":
|
||||||
gen_out = json.loads(gen_out)
|
gen_out = json.loads(gen_out)
|
||||||
@@ -260,6 +264,10 @@ def throughput_test_once(
|
|||||||
measurement_results["total_input_tokens"]
|
measurement_results["total_input_tokens"]
|
||||||
+ measurement_results["total_output_tokens"]
|
+ measurement_results["total_output_tokens"]
|
||||||
) / latency
|
) / latency
|
||||||
|
|
||||||
|
if inspect.isawaitable(server_info):
|
||||||
|
server_info = asyncio.run(server_info)
|
||||||
|
|
||||||
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
|
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
|
||||||
"last_gen_throughput"
|
"last_gen_throughput"
|
||||||
]
|
]
|
||||||
@@ -267,11 +275,9 @@ def throughput_test_once(
|
|||||||
return measurement_results
|
return measurement_results
|
||||||
|
|
||||||
|
|
||||||
def monitor_trace_file(directory, interval=1):
|
def monitor_trace_file(known_files, directory, interval=1):
|
||||||
print(f"Monitoring {directory} for new trace files...")
|
print(f"Monitoring {directory} for new trace files...")
|
||||||
|
|
||||||
known_files = set(os.listdir(directory))
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
flag = False
|
flag = False
|
||||||
time.sleep(interval)
|
time.sleep(interval)
|
||||||
|
|||||||
@@ -85,6 +85,22 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
)
|
)
|
||||||
self._assert_success(res)
|
self._assert_success(res)
|
||||||
|
|
||||||
|
def start_profile(self):
|
||||||
|
res = http_request(
|
||||||
|
self.base_url + "/start_profile",
|
||||||
|
api_key=self.api_key,
|
||||||
|
verify=self.verify,
|
||||||
|
)
|
||||||
|
self._assert_success(res)
|
||||||
|
|
||||||
|
def stop_profile(self):
|
||||||
|
res = http_request(
|
||||||
|
self.base_url + "/stop_profile",
|
||||||
|
api_key=self.api_key,
|
||||||
|
verify=self.verify,
|
||||||
|
)
|
||||||
|
self._assert_success(res)
|
||||||
|
|
||||||
def commit_lazy_operations(self, s: StreamExecutor):
|
def commit_lazy_operations(self, s: StreamExecutor):
|
||||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
@@ -374,7 +390,8 @@ class Runtime:
|
|||||||
self.pid = None
|
self.pid = None
|
||||||
pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
|
pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
|
||||||
|
|
||||||
proc = multiprocessing.Process(
|
ctx = multiprocessing.get_context("spawn")
|
||||||
|
proc = ctx.Process(
|
||||||
target=launch_server,
|
target=launch_server,
|
||||||
args=(self.server_args, pipe_writer),
|
args=(self.server_args, pipe_writer),
|
||||||
)
|
)
|
||||||
@@ -406,6 +423,12 @@ class Runtime:
|
|||||||
kill_process_tree(self.pid)
|
kill_process_tree(self.pid)
|
||||||
self.pid = None
|
self.pid = None
|
||||||
|
|
||||||
|
def start_profile(self):
|
||||||
|
self.endpoint.start_profile()
|
||||||
|
|
||||||
|
def stop_profile(self):
|
||||||
|
self.endpoint.stop_profile()
|
||||||
|
|
||||||
def cache_prefix(self, prefix: str):
|
def cache_prefix(self, prefix: str):
|
||||||
self.endpoint.cache_prefix(prefix)
|
self.endpoint.cache_prefix(prefix)
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
|||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
dataclass_to_string_truncated,
|
dataclass_to_string_truncated,
|
||||||
|
get_bool_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
)
|
)
|
||||||
@@ -805,6 +806,8 @@ class TokenizerManager:
|
|||||||
profile_by_stage: bool = False,
|
profile_by_stage: bool = False,
|
||||||
):
|
):
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
|
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
|
||||||
|
with_stack = False if with_stack is False or env_with_stack is False else True
|
||||||
req = ProfileReq(
|
req = ProfileReq(
|
||||||
type=ProfileReqType.START_PROFILE,
|
type=ProfileReqType.START_PROFILE,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
|
|||||||
Reference in New Issue
Block a user