[Profile] Add pytorch profiler (#1604)

This commit is contained in:
Ying Sheng
2024-10-07 14:37:16 -07:00
committed by GitHub
parent ebbc42d989
commit c5325aba75
4 changed files with 162 additions and 0 deletions

View File

@@ -17,6 +17,7 @@ limitations under the License.
import base64
import ipaddress
import json
import logging
import os
import pickle
@@ -37,6 +38,7 @@ import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
@@ -642,3 +644,34 @@ def broadcast_pyobj(
serialized_data = bytes(tensor_data.cpu().numpy())
data = pickle.loads(serialized_data)
return data
step_counter = 0
def pytorch_profile(name, func, *args, data_size=-1):
"""
Args:
name (string): the name of recorded function.
func: the function to be profiled.
args: the arguments of the profiled function.
data_size (int): some measurement of the computation complexity.
Usually, it could be the batch size.
"""
global step_counter
os.makedirs("trace", exist_ok=True)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
# on_trace_ready=tensorboard_trace_handler('./log_dir'),
record_shapes=True,
profile_memory=True,
with_stack=True,
) as prof:
with record_function(name):
with open(f"trace/size_{step_counter}.json", "w") as f:
json.dump({"size": data_size}, f)
result = func(*args)
prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
step_counter += 1
return result