diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fe8fd895b..f0e78f45b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -65,6 +65,7 @@ from sglang.srt.utils import ( is_generation_model, is_multimodal_model, kill_parent_process, + pytorch_profile, set_random_seed, suppress_other_loggers, ) @@ -409,6 +410,10 @@ class Scheduler: new_batch = self.get_new_batch_prefill() if new_batch is not None: # Run a new prefill batch + # replace run_batch with the uncommented line to use pytorch profiler + # result = pytorch_profile( + # "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs) + # ) result = self.run_batch(new_batch) self.process_batch_result(new_batch, result) else: @@ -418,6 +423,13 @@ class Scheduler: batch = self.get_new_batch_decode() if batch: + # replace run_batch with the uncommented line to use pytorch profiler + # result = pytorch_profile( + # "profile_decode_step", + # self.run_batch, + # batch, + # data_size=len(batch.reqs), + # ) result = self.run_batch(batch) self.process_batch_result(batch, result) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 6d63d42b0..ff8b4575c 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -17,6 +17,7 @@ limitations under the License. import base64 import ipaddress +import json import logging import os import pickle @@ -37,6 +38,7 @@ import torch.distributed as dist from fastapi.responses import JSONResponse from packaging import version as pkg_version from torch import nn +from torch.profiler import ProfilerActivity, profile, record_function from triton.runtime.cache import ( FileCacheManager, default_cache_dir, @@ -642,3 +644,34 @@ def broadcast_pyobj( serialized_data = bytes(tensor_data.cpu().numpy()) data = pickle.loads(serialized_data) return data + + +step_counter = 0 + + +def pytorch_profile(name, func, *args, data_size=-1): + """ + Args: + name (string): the name of recorded function. + func: the function to be profiled. + args: the arguments of the profiled function. + data_size (int): some measurement of the computation complexity. + Usually, it could be the batch size. + """ + global step_counter + os.makedirs("trace", exist_ok=True) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + # on_trace_ready=tensorboard_trace_handler('./log_dir'), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as prof: + with record_function(name): + with open(f"trace/size_{step_counter}.json", "w") as f: + json.dump({"size": data_size}, f) + result = func(*args) + prof.export_chrome_trace(f"trace/{name}_{step_counter}.json") + step_counter += 1 + return result diff --git a/scripts/fix_corrupted_json.py b/scripts/fix_corrupted_json.py new file mode 100644 index 000000000..67c2980f5 --- /dev/null +++ b/scripts/fix_corrupted_json.py @@ -0,0 +1,40 @@ +import json +import re +import sys + + +def clean_json_file(input_file, output_file): + try: + # Open the input file with 'replace' option for handling bad characters + with open(input_file, "r", encoding="utf-8", errors="replace") as f: + data = f.read() + + # Replace bad characters (represented by '�' after decoding) with a space + cleaned_data = data.replace("�", " ") + + # Remove control characters (e.g., ASCII control characters like \x00 to \x1F) + # These can cause issues in JSON parsing. + cleaned_data = re.sub(r"[\x00-\x1F]+", " ", cleaned_data) + + # Parse cleaned data as JSON + json_data = json.loads(cleaned_data) + + # Write the cleaned JSON to a new output file + with open(output_file, "w", encoding="utf-8") as f: + json.dump(json_data, f, ensure_ascii=False, indent=4) + + print(f"Cleaned JSON file has been saved to {output_file}") + + except Exception as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + assert len(sys.argv) > 1, "please give the input file path" + if len(sys.argv) == 3: + input_file = sys.argv[1] + output_file = sys.argv[2] + else: + input_file = output_file = sys.argv[1] + + clean_json_file(input_file, output_file) diff --git a/scripts/playground/lora/analyzer.py b/scripts/playground/lora/analyzer.py new file mode 100644 index 000000000..15568fc18 --- /dev/null +++ b/scripts/playground/lora/analyzer.py @@ -0,0 +1,77 @@ +import glob +import json +import os +import re +import sys + +from tqdm import tqdm + +sys.path.append("../../") +from fix_corrupted_json import clean_json_file + +dirpath = "/Users/ying" +output_file_prefix = "analyzed_log" + +time = {} +tot_time = {} +size = {} + +os.system(f"rm {output_file_prefix}*") + +for dirname in glob.glob(os.path.join(dirpath, "trace*")): + print(dirname) + trace_name = dirname.split("/")[-1] + time[trace_name] = {} + size[trace_name] = {} + total_time = 0 + for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))): + step_name = filename.split("/")[-1].split(".")[0] + step_name = "_".join(step_name.split("_")[1:]) + if "prefill" not in filename and "decode" not in filename: + continue + + match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename) + if match: + phase = match.group(1) + step = match.group(2) + else: + raise Exception(f"Cannot parse {filename}") + + try: + with open(filename, "r") as f: + trace = json.load(f) + except: + clean_json_file(filename, filename) + with open(filename, "r") as f: + trace = json.load(f) + + for event in trace["traceEvents"]: + name = event["name"] + if name in ["profile_prefill_step", "profile_decode_step"]: + dur = event["dur"] / 1e3 + time[trace_name][step_name] = dur + break + total_time += dur + + step = int(step_name.split("_")[-1]) + with open(os.path.join(dirname, f"size_{step}.json"), "r") as f: + size_info = json.load(f) + size[trace_name][step_name] = size_info["size"] + + tot_time[trace_name] = total_time + time[trace_name] = dict( + sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1])) + ) + size[trace_name] = dict( + sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1])) + ) + + with open(f"{output_file_prefix}_{trace_name}", "a") as f: + for k, v in time[trace_name].items(): + size_v = size[trace_name][k] + print(f"{k:>15}{v:10.2f}\t{size_v}") + f.write(f"{k:>15}{v:10.2f}\t{size_v}\n") + +with open(f"{output_file_prefix}_total_time", "w") as f: + print(tot_time) + json.dump(tot_time, f)