[Profile] Add pytorch profiler (#1604)
This commit is contained in:
@@ -65,6 +65,7 @@ from sglang.srt.utils import (
|
||||
is_generation_model,
|
||||
is_multimodal_model,
|
||||
kill_parent_process,
|
||||
pytorch_profile,
|
||||
set_random_seed,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
@@ -409,6 +410,10 @@ class Scheduler:
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
if new_batch is not None:
|
||||
# Run a new prefill batch
|
||||
# replace run_batch with the uncommented line to use pytorch profiler
|
||||
# result = pytorch_profile(
|
||||
# "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
|
||||
# )
|
||||
result = self.run_batch(new_batch)
|
||||
self.process_batch_result(new_batch, result)
|
||||
else:
|
||||
@@ -418,6 +423,13 @@ class Scheduler:
|
||||
batch = self.get_new_batch_decode()
|
||||
|
||||
if batch:
|
||||
# replace run_batch with the uncommented line to use pytorch profiler
|
||||
# result = pytorch_profile(
|
||||
# "profile_decode_step",
|
||||
# self.run_batch,
|
||||
# batch,
|
||||
# data_size=len(batch.reqs),
|
||||
# )
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
@@ -37,6 +38,7 @@ import torch.distributed as dist
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from torch import nn
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
default_cache_dir,
|
||||
@@ -642,3 +644,34 @@ def broadcast_pyobj(
|
||||
serialized_data = bytes(tensor_data.cpu().numpy())
|
||||
data = pickle.loads(serialized_data)
|
||||
return data
|
||||
|
||||
|
||||
step_counter = 0
|
||||
|
||||
|
||||
def pytorch_profile(name, func, *args, data_size=-1):
|
||||
"""
|
||||
Args:
|
||||
name (string): the name of recorded function.
|
||||
func: the function to be profiled.
|
||||
args: the arguments of the profiled function.
|
||||
data_size (int): some measurement of the computation complexity.
|
||||
Usually, it could be the batch size.
|
||||
"""
|
||||
global step_counter
|
||||
os.makedirs("trace", exist_ok=True)
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
|
||||
# on_trace_ready=tensorboard_trace_handler('./log_dir'),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
) as prof:
|
||||
with record_function(name):
|
||||
with open(f"trace/size_{step_counter}.json", "w") as f:
|
||||
json.dump({"size": data_size}, f)
|
||||
result = func(*args)
|
||||
prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
|
||||
step_counter += 1
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user