Time cost utils (#355)
This commit is contained in:
@@ -11,48 +11,56 @@ from typing import List, Optional
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
is_show_cost_time = False
|
||||
show_time_cost = False
|
||||
time_infos = {}
|
||||
|
||||
|
||||
def mark_cost_time(func_name):
|
||||
def inner_func(func):
|
||||
def time_func(*args, **kwargs):
|
||||
if dist.get_rank() in [0, 1] and is_show_cost_time:
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
ans = func(*args, **kwargs)
|
||||
torch.cuda.synchronize()
|
||||
print(func_name, "cost time:", (time.time() - start_time) * 1000)
|
||||
return ans
|
||||
else:
|
||||
torch.cuda.synchronize()
|
||||
ans = func(*args, **kwargs)
|
||||
torch.cuda.synchronize()
|
||||
return ans
|
||||
|
||||
return time_func
|
||||
|
||||
return inner_func
|
||||
def enable_show_time_cost():
|
||||
global show_time_cost
|
||||
show_time_cost = True
|
||||
|
||||
|
||||
time_mark = {}
|
||||
class TimeInfo:
|
||||
def __init__(self, name, interval=0.1, color=0, indent=0):
|
||||
self.name = name
|
||||
self.interval = interval
|
||||
self.color = color
|
||||
self.indent = indent
|
||||
|
||||
self.acc_time = 0
|
||||
self.last_acc_time = 0
|
||||
|
||||
def check(self):
|
||||
if self.acc_time - self.last_acc_time > self.interval:
|
||||
self.last_acc_time = self.acc_time
|
||||
return True
|
||||
return False
|
||||
|
||||
def pretty_print(self):
|
||||
print(f"\x1b[{self.color}m", end="")
|
||||
print("-" * self.indent * 2, end="")
|
||||
print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")
|
||||
|
||||
|
||||
def mark_start(key):
|
||||
def mark_start(name, interval=0.1, color=0, indent=0):
|
||||
global time_infos, show_time_cost
|
||||
if not show_time_cost:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
global time_mark
|
||||
time_mark[key] = time.time()
|
||||
return
|
||||
if time_infos.get(name, None) is None:
|
||||
time_infos[name] = TimeInfo(name, interval, color, indent)
|
||||
time_infos[name].acc_time -= time.time()
|
||||
|
||||
|
||||
def mark_end(key, print_min_cost=0.0):
|
||||
def mark_end(name):
|
||||
global time_infos, show_time_cost
|
||||
if not show_time_cost:
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
global time_mark
|
||||
cost_time = (time.time() - time_mark[key]) * 1000
|
||||
if cost_time > print_min_cost:
|
||||
print(f"cost {key}:", cost_time)
|
||||
time_infos[name].acc_time += time.time()
|
||||
if time_infos[name].check():
|
||||
time_infos[name].pretty_print()
|
||||
|
||||
|
||||
def calculate_time(show=False, min_cost_ms=0.0):
|
||||
|
||||
Reference in New Issue
Block a user