[FEATURE] Add Profile Trace Merger for Distributed Traces (#11413)

This commit is contained in:
Neelabh Sinha
2025-10-13 18:20:17 -07:00
committed by GitHub
parent 932e263725
commit aaf7af1b17
10 changed files with 849 additions and 11 deletions

View File

@@ -25,6 +25,7 @@ def _run_profile(
output_dir: Optional[str] = None,
profile_name: Optional[str] = None,
profile_by_stage: bool = False,
merge_profiles: bool = False,
) -> str:
if output_dir is None:
output_dir = PROFILER_DIR
@@ -60,6 +61,7 @@ def _run_profile(
"num_steps": str(num_steps),
"activities": activities,
"profile_by_stage": profile_by_stage,
"merge_profiles": merge_profiles,
}
response = requests.post(url=url + "/start_profile", json=json_data)
@@ -76,10 +78,17 @@ def run_profile(
output_dir: Optional[str] = None,
profile_name: Optional[str] = None,
profile_by_stage: bool = False,
merge_profiles: bool = False,
):
# step based profile will self terminate on num_steps constraints
link = _run_profile(
url, num_steps, activities, output_dir, profile_name, profile_by_stage
url,
num_steps,
activities,
output_dir,
profile_name,
profile_by_stage,
merge_profiles,
)
return link
@@ -145,6 +154,13 @@ if __name__ == "__main__":
default=False,
help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
)
parser.add_argument(
"--merge-profiles",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="Whether to merge profiles from all ranks into a single trace file",
)
args = parser.parse_args()
activities = []
@@ -163,4 +179,5 @@ if __name__ == "__main__":
args.output_dir,
args.profile_name,
args.profile_by_stage,
args.merge_profiles,
)

View File

@@ -634,6 +634,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
with_stack=obj.with_stack,
record_shapes=obj.record_shapes,
profile_by_stage=obj.profile_by_stage,
merge_profiles=obj.merge_profiles,
)
return Response(
content="Start profiling.\n",

View File

@@ -1232,6 +1232,8 @@ class ProfileReqInput(BaseReq):
profile_by_stage: bool = False
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
# Merge profiles from all ranks into a single trace
merge_profiles: bool = False
class ProfileReqType(Enum):
@@ -1250,6 +1252,8 @@ class ProfileReq(BaseReq):
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
profile_id: Optional[str] = None
# Merge profiles from all ranks into a single trace
merge_profiles: bool = False
@dataclass

View File

@@ -9,6 +9,7 @@ import torch
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_npu
from sglang.srt.utils.profile_merger import ProfileMerger
_is_npu = is_npu()
if _is_npu:
@@ -25,7 +26,6 @@ logger = logging.getLogger(__name__)
class SchedulerProfilerMixin:
def init_profiler(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
@@ -41,6 +41,7 @@ class SchedulerProfilerMixin:
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
self.merge_profiles = False
def init_profile(
self,
@@ -52,6 +53,7 @@ class SchedulerProfilerMixin:
record_shapes: Optional[bool],
profile_by_stage: bool,
profile_id: str,
merge_profiles: bool = False,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
@@ -60,6 +62,7 @@ class SchedulerProfilerMixin:
)
self.profile_by_stage = profile_by_stage
self.merge_profiles = merge_profiles
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
@@ -169,6 +172,38 @@ class SchedulerProfilerMixin:
return ProfileReqOutput(success=True, message="Succeeded")
def _merge_profile_traces(self) -> str:
if not self.merge_profiles:
return ""
if self.tp_rank != 0:
return ""
if getattr(self, "dp_size", 1) > 1 and getattr(self, "dp_rank", 0) != 0:
return ""
if getattr(self, "pp_size", 1) > 1 and getattr(self, "pp_rank", 0) != 0:
return ""
if getattr(self, "moe_ep_size", 1) > 1 and getattr(self, "moe_ep_rank", 0) != 0:
return ""
try:
logger.info("Starting profile merge...")
merger = ProfileMerger(self.torch_profiler_output_dir, self.profile_id)
merged_path = merger.merge_chrome_traces()
summary = merger.get_merge_summary()
merge_message = (
f" Merged trace: {merged_path} "
f"(Events: {summary.get('total_events', '?')}, "
f"Files: {summary.get('total_files', '?')})"
)
logger.info(f"Profile merge completed: {merged_path}")
except Exception as e:
logger.error(f"Failed to merge profiles: {e}", exc_info=True)
return f" Merge failed: {e!s}"
else:
return merge_message
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
@@ -186,14 +221,21 @@ class SchedulerProfilerMixin:
if self.torch_profiler is not None:
self.torch_profiler.stop()
if not _is_npu:
# Build filename with only non-zero ranks to maintain backward compatibility
filename_parts = [self.profile_id, f"TP-{self.tp_rank}"]
# Only add other ranks if parallelism is enabled (size > 1)
if getattr(self, "dp_size", 1) > 1:
filename_parts.append(f"DP-{getattr(self, 'dp_rank', 0)}")
if getattr(self, "pp_size", 1) > 1:
filename_parts.append(f"PP-{getattr(self, 'pp_rank', 0)}")
if getattr(self, "moe_ep_size", 1) > 1:
filename_parts.append(f"EP-{getattr(self, 'moe_ep_rank', 0)}")
filename = "-".join(filename_parts) + stage_suffix + ".trace.json.gz"
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
self.profile_id
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
)
os.path.join(self.torch_profiler_output_dir, filename)
)
torch.distributed.barrier(self.tp_cpu_group)
@@ -224,15 +266,18 @@ class SchedulerProfilerMixin:
if "CUDA_PROFILER" in self.profiler_activities:
torch.cuda.cudart().cudaProfilerStop()
merge_message = self._merge_profile_traces()
logger.info(
"Profiling done. Traces are saved to: %s",
"Profiling done. Traces are saved to: %s%s",
self.torch_profiler_output_dir,
merge_message,
)
self.torch_profiler = None
self.profile_in_progress = False
self.profiler_start_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded.")
return ProfileReqOutput(success=True, message=f"Succeeded.{merge_message}")
def _profile_batch_predicate(self, batch):
if self.profile_by_stage:
@@ -282,6 +327,7 @@ class SchedulerProfilerMixin:
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
recv_req.merge_profiles,
)
else:
self.init_profile(
@@ -293,6 +339,7 @@ class SchedulerProfilerMixin:
recv_req.record_shapes,
recv_req.profile_by_stage,
recv_req.profile_id,
recv_req.merge_profiles,
)
return self.start_profile()
else:

View File

@@ -306,6 +306,7 @@ class TokenizerCommunicatorMixin:
with_stack: Optional[bool] = None,
record_shapes: Optional[bool] = None,
profile_by_stage: bool = False,
merge_profiles: bool = False,
):
self.auto_create_handle_loop()
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
@@ -320,6 +321,7 @@ class TokenizerCommunicatorMixin:
record_shapes=record_shapes,
profile_by_stage=profile_by_stage,
profile_id=str(time.time()),
merge_profiles=merge_profiles,
)
return await self._execute_profile(req)

View File

@@ -0,0 +1,199 @@
"""Merge Chrome trace files from multiple ranks (TP, DP, PP, EP) into a single trace."""
import glob
import gzip
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class ProfileMerger:
"""Merge profile traces from all parallelism types: TP, DP, PP, EP."""
def __init__(self, output_dir: str, profile_id: str):
self.output_dir = output_dir
self.profile_id = profile_id
self.merged_trace_path = os.path.join(
output_dir, f"merged-{profile_id}.trace.json.gz"
)
# Rank types in priority order (used for sorting and labeling)
self.rank_types = ["tp", "dp", "pp", "ep"]
# Sort index multipliers: DP (highest) > EP > PP > TP (lowest)
# These ensure proper visual ordering in trace viewer
self.sort_index_multipliers = {
"dp_rank": 100_000_000,
"ep_rank": 1_000_000,
"pp_rank": 10_000,
"tp_rank": 100,
}
# PID threshold for sort_index updates (only update for system PIDs < 1000)
self.pid_sort_index_threshold = 1000
def merge_chrome_traces(self) -> str:
"""Merge Chrome traces from all ranks into a single trace.
Returns:
Path to merged trace file.
Raises:
ValueError: If no trace files found.
"""
trace_files = self._discover_trace_files()
if not trace_files:
raise ValueError(f"No trace files found for profile_id: {self.profile_id}")
logger.info(f"Found {len(trace_files)} trace files to merge")
merged_trace = {"traceEvents": []}
all_device_properties = []
for trace_file in sorted(trace_files, key=self._get_rank_sort_key):
rank_info = self._extract_rank_info(trace_file)
logger.info(f"Processing {trace_file} with rank info: {rank_info}")
output = self._handle_file(trace_file, rank_info)
merged_trace["traceEvents"].extend(output["traceEvents"])
if "deviceProperties" in output:
all_device_properties.extend(output["deviceProperties"])
del output["deviceProperties"]
for key, value in output.items():
if key != "traceEvents" and key not in merged_trace:
merged_trace[key] = value
if all_device_properties:
merged_trace["deviceProperties"] = all_device_properties
with gzip.open(self.merged_trace_path, "wb") as f:
f.write(json.dumps(merged_trace).encode("utf-8"))
logger.info(f"Merged profile saved to: {self.merged_trace_path}")
logger.info(f"Total events merged: {len(merged_trace['traceEvents'])}")
return self.merged_trace_path
def _discover_trace_files(self) -> List[str]:
"""Discover trace files matching profile_id (supports TP/DP/PP/EP formats)."""
patterns = [f"{self.profile_id}*.trace.json.gz"]
trace_files = []
for pattern in patterns:
search_pattern = os.path.join(self.output_dir, pattern)
trace_files.extend(glob.glob(search_pattern))
trace_files = [
f
for f in trace_files
if not f.endswith(f"merged-{self.profile_id}.trace.json.gz")
and not f.endswith("-memory.pickle")
and "TP-" in f
]
trace_files = list(set(trace_files))
return trace_files
def _extract_rank_info(self, filename: str) -> Dict[str, int]:
"""Extract rank info (TP/DP/PP/EP) from filename."""
basename = os.path.basename(filename)
rank_info = {}
for rank_type in self.rank_types:
match = re.search(rf"{rank_type.upper()}-(\d+)", basename)
if match:
rank_info[f"{rank_type}_rank"] = int(match.group(1))
return rank_info
def _create_rank_label(self, rank_info: Dict[str, int]) -> str:
parts = []
for rank_type in self.rank_types:
rank_key = f"{rank_type}_rank"
if rank_key in rank_info:
parts.append(f"{rank_type.upper()}{rank_info[rank_key]:02d}")
return f"[{'-'.join(parts)}]" if parts else "[Unknown]"
def _handle_file(self, path: str, rank_info: Dict[str, int]) -> Dict[str, Any]:
logger.info(f"Processing file: {path}")
try:
with gzip.open(path, "rt", encoding="utf-8") as f:
trace = json.load(f)
output = {
key: value for key, value in trace.items() if key != "traceEvents"
}
output["traceEvents"] = self._process_events(
trace.get("traceEvents", []), rank_info
)
return output
except Exception as e:
logger.error(f"Failed to process trace file {path}: {e}")
return {"traceEvents": []}
def _process_events(
self, events: List[Dict], rank_info: Dict[str, int]
) -> List[Dict]:
"""Process events: update sort_index and add rank labels to PIDs."""
rank_label = self._create_rank_label(rank_info)
for event in events:
if event.get("name") == "process_sort_index":
pid = self._maybe_cast_int(event.get("pid"))
if pid is not None and pid < self.pid_sort_index_threshold:
event["args"]["sort_index"] = self._calculate_sort_index(
rank_info, pid
)
event["pid"] = f"{rank_label} {event['pid']}"
return events
def _calculate_sort_index(self, rank_info: Dict[str, int], pid: int) -> int:
sort_index = pid
for rank_type, multiplier in self.sort_index_multipliers.items():
sort_index += rank_info.get(rank_type, 0) * multiplier
return sort_index
def _get_rank_sort_key(self, path: str) -> Tuple[int, int, int, int]:
rank_info = self._extract_rank_info(path)
return tuple(
rank_info.get(f"{rank_type}_rank", 0)
for rank_type in ["dp", "ep", "pp", "tp"]
)
def _maybe_cast_int(self, x) -> Optional[int]:
try:
return int(x)
except (ValueError, TypeError):
return None
def get_merge_summary(self) -> Dict[str, Any]:
if not os.path.exists(self.merged_trace_path):
return {"error": "Merged trace file not found"}
try:
with gzip.open(self.merged_trace_path, "rt") as f:
merged_data = json.load(f)
trace_files = self._discover_trace_files()
return {
"merged_file": self.merged_trace_path,
"total_events": len(merged_data.get("traceEvents", [])),
"total_files": len(trace_files),
"source_files": [os.path.basename(f) for f in trace_files],
"profile_id": self.profile_id,
"device_properties_count": len(merged_data.get("deviceProperties", [])),
}
except Exception as e:
return {"error": f"Failed to read merged trace: {str(e)}"}