[Profile] Add pytorch profiler (#1604)
This commit is contained in:
@@ -65,6 +65,7 @@ from sglang.srt.utils import (
|
|||||||
is_generation_model,
|
is_generation_model,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
kill_parent_process,
|
kill_parent_process,
|
||||||
|
pytorch_profile,
|
||||||
set_random_seed,
|
set_random_seed,
|
||||||
suppress_other_loggers,
|
suppress_other_loggers,
|
||||||
)
|
)
|
||||||
@@ -409,6 +410,10 @@ class Scheduler:
|
|||||||
new_batch = self.get_new_batch_prefill()
|
new_batch = self.get_new_batch_prefill()
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
# Run a new prefill batch
|
# Run a new prefill batch
|
||||||
|
# replace run_batch with the uncommented line to use pytorch profiler
|
||||||
|
# result = pytorch_profile(
|
||||||
|
# "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
|
||||||
|
# )
|
||||||
result = self.run_batch(new_batch)
|
result = self.run_batch(new_batch)
|
||||||
self.process_batch_result(new_batch, result)
|
self.process_batch_result(new_batch, result)
|
||||||
else:
|
else:
|
||||||
@@ -418,6 +423,13 @@ class Scheduler:
|
|||||||
batch = self.get_new_batch_decode()
|
batch = self.get_new_batch_decode()
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
|
# replace run_batch with the uncommented line to use pytorch profiler
|
||||||
|
# result = pytorch_profile(
|
||||||
|
# "profile_decode_step",
|
||||||
|
# self.run_batch,
|
||||||
|
# batch,
|
||||||
|
# data_size=len(batch.reqs),
|
||||||
|
# )
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -37,6 +38,7 @@ import torch.distributed as dist
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.profiler import ProfilerActivity, profile, record_function
|
||||||
from triton.runtime.cache import (
|
from triton.runtime.cache import (
|
||||||
FileCacheManager,
|
FileCacheManager,
|
||||||
default_cache_dir,
|
default_cache_dir,
|
||||||
@@ -642,3 +644,34 @@ def broadcast_pyobj(
|
|||||||
serialized_data = bytes(tensor_data.cpu().numpy())
|
serialized_data = bytes(tensor_data.cpu().numpy())
|
||||||
data = pickle.loads(serialized_data)
|
data = pickle.loads(serialized_data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
step_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def pytorch_profile(name, func, *args, data_size=-1):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
name (string): the name of recorded function.
|
||||||
|
func: the function to be profiled.
|
||||||
|
args: the arguments of the profiled function.
|
||||||
|
data_size (int): some measurement of the computation complexity.
|
||||||
|
Usually, it could be the batch size.
|
||||||
|
"""
|
||||||
|
global step_counter
|
||||||
|
os.makedirs("trace", exist_ok=True)
|
||||||
|
with profile(
|
||||||
|
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||||
|
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
|
||||||
|
# on_trace_ready=tensorboard_trace_handler('./log_dir'),
|
||||||
|
record_shapes=True,
|
||||||
|
profile_memory=True,
|
||||||
|
with_stack=True,
|
||||||
|
) as prof:
|
||||||
|
with record_function(name):
|
||||||
|
with open(f"trace/size_{step_counter}.json", "w") as f:
|
||||||
|
json.dump({"size": data_size}, f)
|
||||||
|
result = func(*args)
|
||||||
|
prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
|
||||||
|
step_counter += 1
|
||||||
|
return result
|
||||||
|
|||||||
40
scripts/fix_corrupted_json.py
Normal file
40
scripts/fix_corrupted_json.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def clean_json_file(input_file, output_file):
|
||||||
|
try:
|
||||||
|
# Open the input file with 'replace' option for handling bad characters
|
||||||
|
with open(input_file, "r", encoding="utf-8", errors="replace") as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
# Replace bad characters (represented by '<27>' after decoding) with a space
|
||||||
|
cleaned_data = data.replace("<EFBFBD>", " ")
|
||||||
|
|
||||||
|
# Remove control characters (e.g., ASCII control characters like \x00 to \x1F)
|
||||||
|
# These can cause issues in JSON parsing.
|
||||||
|
cleaned_data = re.sub(r"[\x00-\x1F]+", " ", cleaned_data)
|
||||||
|
|
||||||
|
# Parse cleaned data as JSON
|
||||||
|
json_data = json.loads(cleaned_data)
|
||||||
|
|
||||||
|
# Write the cleaned JSON to a new output file
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
print(f"Cleaned JSON file has been saved to {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
assert len(sys.argv) > 1, "please give the input file path"
|
||||||
|
if len(sys.argv) == 3:
|
||||||
|
input_file = sys.argv[1]
|
||||||
|
output_file = sys.argv[2]
|
||||||
|
else:
|
||||||
|
input_file = output_file = sys.argv[1]
|
||||||
|
|
||||||
|
clean_json_file(input_file, output_file)
|
||||||
77
scripts/playground/lora/analyzer.py
Normal file
77
scripts/playground/lora/analyzer.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
sys.path.append("../../")
|
||||||
|
from fix_corrupted_json import clean_json_file
|
||||||
|
|
||||||
|
dirpath = "/Users/ying"
|
||||||
|
output_file_prefix = "analyzed_log"
|
||||||
|
|
||||||
|
time = {}
|
||||||
|
tot_time = {}
|
||||||
|
size = {}
|
||||||
|
|
||||||
|
os.system(f"rm {output_file_prefix}*")
|
||||||
|
|
||||||
|
for dirname in glob.glob(os.path.join(dirpath, "trace*")):
|
||||||
|
print(dirname)
|
||||||
|
trace_name = dirname.split("/")[-1]
|
||||||
|
time[trace_name] = {}
|
||||||
|
size[trace_name] = {}
|
||||||
|
total_time = 0
|
||||||
|
for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
|
||||||
|
step_name = filename.split("/")[-1].split(".")[0]
|
||||||
|
step_name = "_".join(step_name.split("_")[1:])
|
||||||
|
if "prefill" not in filename and "decode" not in filename:
|
||||||
|
continue
|
||||||
|
|
||||||
|
match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename)
|
||||||
|
if match:
|
||||||
|
phase = match.group(1)
|
||||||
|
step = match.group(2)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Cannot parse {filename}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
trace = json.load(f)
|
||||||
|
except:
|
||||||
|
clean_json_file(filename, filename)
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
trace = json.load(f)
|
||||||
|
|
||||||
|
for event in trace["traceEvents"]:
|
||||||
|
name = event["name"]
|
||||||
|
if name in ["profile_prefill_step", "profile_decode_step"]:
|
||||||
|
dur = event["dur"] / 1e3
|
||||||
|
time[trace_name][step_name] = dur
|
||||||
|
break
|
||||||
|
total_time += dur
|
||||||
|
|
||||||
|
step = int(step_name.split("_")[-1])
|
||||||
|
with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
|
||||||
|
size_info = json.load(f)
|
||||||
|
size[trace_name][step_name] = size_info["size"]
|
||||||
|
|
||||||
|
tot_time[trace_name] = total_time
|
||||||
|
time[trace_name] = dict(
|
||||||
|
sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||||
|
)
|
||||||
|
size[trace_name] = dict(
|
||||||
|
sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(f"{output_file_prefix}_{trace_name}", "a") as f:
|
||||||
|
for k, v in time[trace_name].items():
|
||||||
|
size_v = size[trace_name][k]
|
||||||
|
print(f"{k:>15}{v:10.2f}\t{size_v}")
|
||||||
|
f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")
|
||||||
|
|
||||||
|
with open(f"{output_file_prefix}_total_time", "w") as f:
|
||||||
|
print(tot_time)
|
||||||
|
json.dump(tot_time, f)
|
||||||
Reference in New Issue
Block a user