Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)
This commit is contained in:
137
python/sglang/srt/bench_utils.py
Normal file
137
python/sglang/srt/bench_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
|
||||
class suppress_stdout_stderr:
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, "w")
|
||||
self.errnull_file = open(os.devnull, "w")
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
|
||||
def bench_kineto(
|
||||
fn,
|
||||
kernel_names,
|
||||
num_tests: int = 30,
|
||||
suppress_kineto_output: bool = False,
|
||||
trace_path: str = None,
|
||||
flush_l2: bool = True,
|
||||
with_multiple_kernels: bool = False,
|
||||
):
|
||||
# Conflict with Nsight Systems
|
||||
using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))
|
||||
|
||||
# By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
|
||||
flush_l2_size = int(8e9 // 4)
|
||||
|
||||
# For some auto-tuning kernels with prints
|
||||
fn()
|
||||
|
||||
# Profile
|
||||
suppress = (
|
||||
suppress_stdout_stderr
|
||||
if suppress_kineto_output and not using_nsys
|
||||
else nullcontext
|
||||
)
|
||||
with suppress():
|
||||
schedule = (
|
||||
torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
|
||||
if not using_nsys
|
||||
else None
|
||||
)
|
||||
profiler = (
|
||||
torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
|
||||
)
|
||||
if not using_nsys
|
||||
else nullcontext()
|
||||
)
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
for _ in range(num_tests):
|
||||
if flush_l2:
|
||||
torch.empty(
|
||||
flush_l2_size, dtype=torch.int, device="cuda"
|
||||
).zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
profiler.step()
|
||||
|
||||
# Return 1 if using Nsight Systems
|
||||
if using_nsys:
|
||||
return 1
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tuple = isinstance(kernel_names, tuple)
|
||||
prof_lines = (
|
||||
profiler.key_averages()
|
||||
.table(sort_by="cuda_time_total", max_name_column_width=100)
|
||||
.split("\n")
|
||||
)
|
||||
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert (
|
||||
sum([name in line for line in prof_lines]) == 1
|
||||
), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
profiler.export_chrome_trace(trace_path)
|
||||
|
||||
# Return average kernel times
|
||||
units = {"ms": 1e3, "us": 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
total_time = 0
|
||||
total_num = 0
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
num_str = line.split()[-1]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
total_time += (
|
||||
float(time_str.replace(unit, "")) / scale * int(num_str)
|
||||
)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
|
||||
return tuple(kernel_times) if is_tuple else kernel_times[0]
|
||||
Reference in New Issue
Block a user