[FEATURE] Add Profile Trace Merger for Distributed Traces (#11413)
This commit is contained in:
@@ -25,6 +25,7 @@ def _run_profile(
|
||||
output_dir: Optional[str] = None,
|
||||
profile_name: Optional[str] = None,
|
||||
profile_by_stage: bool = False,
|
||||
merge_profiles: bool = False,
|
||||
) -> str:
|
||||
if output_dir is None:
|
||||
output_dir = PROFILER_DIR
|
||||
@@ -60,6 +61,7 @@ def _run_profile(
|
||||
"num_steps": str(num_steps),
|
||||
"activities": activities,
|
||||
"profile_by_stage": profile_by_stage,
|
||||
"merge_profiles": merge_profiles,
|
||||
}
|
||||
|
||||
response = requests.post(url=url + "/start_profile", json=json_data)
|
||||
@@ -76,10 +78,17 @@ def run_profile(
|
||||
output_dir: Optional[str] = None,
|
||||
profile_name: Optional[str] = None,
|
||||
profile_by_stage: bool = False,
|
||||
merge_profiles: bool = False,
|
||||
):
|
||||
# step based profile will self terminate on num_steps constraints
|
||||
link = _run_profile(
|
||||
url, num_steps, activities, output_dir, profile_name, profile_by_stage
|
||||
url,
|
||||
num_steps,
|
||||
activities,
|
||||
output_dir,
|
||||
profile_name,
|
||||
profile_by_stage,
|
||||
merge_profiles,
|
||||
)
|
||||
return link
|
||||
|
||||
@@ -145,6 +154,13 @@ if __name__ == "__main__":
|
||||
default=False,
|
||||
help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
|
||||
)
|
||||
parser.add_argument(
    "--merge-profiles",
    # NOTE: do not pass `type=` together with BooleanOptionalAction — the
    # action consumes no value (it produces True/False from the flag name
    # itself), so `type` is dead code, and combining them is deprecated in
    # Python 3.12 and an error in 3.14.
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Whether to merge profiles from all ranks into a single trace file",
)
|
||||
|
||||
args = parser.parse_args()
|
||||
activities = []
|
||||
@@ -163,4 +179,5 @@ if __name__ == "__main__":
|
||||
args.output_dir,
|
||||
args.profile_name,
|
||||
args.profile_by_stage,
|
||||
args.merge_profiles,
|
||||
)
|
||||
|
||||
@@ -634,6 +634,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
|
||||
with_stack=obj.with_stack,
|
||||
record_shapes=obj.record_shapes,
|
||||
profile_by_stage=obj.profile_by_stage,
|
||||
merge_profiles=obj.merge_profiles,
|
||||
)
|
||||
return Response(
|
||||
content="Start profiling.\n",
|
||||
|
||||
@@ -1232,6 +1232,8 @@ class ProfileReqInput(BaseReq):
|
||||
profile_by_stage: bool = False
|
||||
with_stack: Optional[bool] = None
|
||||
record_shapes: Optional[bool] = None
|
||||
# Merge profiles from all ranks into a single trace
|
||||
merge_profiles: bool = False
|
||||
|
||||
|
||||
class ProfileReqType(Enum):
|
||||
@@ -1250,6 +1252,8 @@ class ProfileReq(BaseReq):
|
||||
with_stack: Optional[bool] = None
|
||||
record_shapes: Optional[bool] = None
|
||||
profile_id: Optional[str] = None
|
||||
# Merge profiles from all ranks into a single trace
|
||||
merge_profiles: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import is_npu
|
||||
from sglang.srt.utils.profile_merger import ProfileMerger
|
||||
|
||||
_is_npu = is_npu()
|
||||
if _is_npu:
|
||||
@@ -25,7 +26,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerProfilerMixin:
|
||||
|
||||
def init_profiler(self):
|
||||
self.torch_profiler = None
|
||||
self.torch_profiler_output_dir: Optional[str] = None
|
||||
@@ -41,6 +41,7 @@ class SchedulerProfilerMixin:
|
||||
self.profile_steps: Optional[int] = None
|
||||
self.profile_in_progress: bool = False
|
||||
self.rpd_profiler = None
|
||||
self.merge_profiles = False
|
||||
|
||||
def init_profile(
|
||||
self,
|
||||
@@ -52,6 +53,7 @@ class SchedulerProfilerMixin:
|
||||
record_shapes: Optional[bool],
|
||||
profile_by_stage: bool,
|
||||
profile_id: str,
|
||||
merge_profiles: bool = False,
|
||||
) -> ProfileReqOutput:
|
||||
if self.profile_in_progress:
|
||||
return ProfileReqOutput(
|
||||
@@ -60,6 +62,7 @@ class SchedulerProfilerMixin:
|
||||
)
|
||||
|
||||
self.profile_by_stage = profile_by_stage
|
||||
self.merge_profiles = merge_profiles
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
|
||||
@@ -169,6 +172,38 @@ class SchedulerProfilerMixin:
|
||||
|
||||
return ProfileReqOutput(success=True, message="Succeeded")
|
||||
|
||||
def _merge_profile_traces(self) -> str:
|
||||
if not self.merge_profiles:
|
||||
return ""
|
||||
|
||||
if self.tp_rank != 0:
|
||||
return ""
|
||||
if getattr(self, "dp_size", 1) > 1 and getattr(self, "dp_rank", 0) != 0:
|
||||
return ""
|
||||
if getattr(self, "pp_size", 1) > 1 and getattr(self, "pp_rank", 0) != 0:
|
||||
return ""
|
||||
if getattr(self, "moe_ep_size", 1) > 1 and getattr(self, "moe_ep_rank", 0) != 0:
|
||||
return ""
|
||||
|
||||
try:
|
||||
logger.info("Starting profile merge...")
|
||||
merger = ProfileMerger(self.torch_profiler_output_dir, self.profile_id)
|
||||
merged_path = merger.merge_chrome_traces()
|
||||
|
||||
summary = merger.get_merge_summary()
|
||||
merge_message = (
|
||||
f" Merged trace: {merged_path} "
|
||||
f"(Events: {summary.get('total_events', '?')}, "
|
||||
f"Files: {summary.get('total_files', '?')})"
|
||||
)
|
||||
|
||||
logger.info(f"Profile merge completed: {merged_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to merge profiles: {e}", exc_info=True)
|
||||
return f" Merge failed: {e!s}"
|
||||
else:
|
||||
return merge_message
|
||||
|
||||
def stop_profile(
|
||||
self, stage: Optional[ForwardMode] = None
|
||||
) -> ProfileReqOutput | None:
|
||||
@@ -186,14 +221,21 @@ class SchedulerProfilerMixin:
|
||||
if self.torch_profiler is not None:
|
||||
self.torch_profiler.stop()
|
||||
if not _is_npu:
|
||||
# Build filename with only non-zero ranks to maintain backward compatibility
|
||||
filename_parts = [self.profile_id, f"TP-{self.tp_rank}"]
|
||||
|
||||
# Only add other ranks if parallelism is enabled (size > 1)
|
||||
if getattr(self, "dp_size", 1) > 1:
|
||||
filename_parts.append(f"DP-{getattr(self, 'dp_rank', 0)}")
|
||||
if getattr(self, "pp_size", 1) > 1:
|
||||
filename_parts.append(f"PP-{getattr(self, 'pp_rank', 0)}")
|
||||
if getattr(self, "moe_ep_size", 1) > 1:
|
||||
filename_parts.append(f"EP-{getattr(self, 'moe_ep_rank', 0)}")
|
||||
|
||||
filename = "-".join(filename_parts) + stage_suffix + ".trace.json.gz"
|
||||
|
||||
self.torch_profiler.export_chrome_trace(
|
||||
os.path.join(
|
||||
self.torch_profiler_output_dir,
|
||||
self.profile_id
|
||||
+ f"-TP-{self.tp_rank}"
|
||||
+ stage_suffix
|
||||
+ ".trace.json.gz",
|
||||
)
|
||||
os.path.join(self.torch_profiler_output_dir, filename)
|
||||
)
|
||||
torch.distributed.barrier(self.tp_cpu_group)
|
||||
|
||||
@@ -224,15 +266,18 @@ class SchedulerProfilerMixin:
|
||||
if "CUDA_PROFILER" in self.profiler_activities:
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
|
||||
merge_message = self._merge_profile_traces()
|
||||
|
||||
logger.info(
|
||||
"Profiling done. Traces are saved to: %s",
|
||||
"Profiling done. Traces are saved to: %s%s",
|
||||
self.torch_profiler_output_dir,
|
||||
merge_message,
|
||||
)
|
||||
self.torch_profiler = None
|
||||
self.profile_in_progress = False
|
||||
self.profiler_start_forward_ct = None
|
||||
|
||||
return ProfileReqOutput(success=True, message="Succeeded.")
|
||||
return ProfileReqOutput(success=True, message=f"Succeeded.{merge_message}")
|
||||
|
||||
def _profile_batch_predicate(self, batch):
|
||||
if self.profile_by_stage:
|
||||
@@ -282,6 +327,7 @@ class SchedulerProfilerMixin:
|
||||
recv_req.record_shapes,
|
||||
recv_req.profile_by_stage,
|
||||
recv_req.profile_id,
|
||||
recv_req.merge_profiles,
|
||||
)
|
||||
else:
|
||||
self.init_profile(
|
||||
@@ -293,6 +339,7 @@ class SchedulerProfilerMixin:
|
||||
recv_req.record_shapes,
|
||||
recv_req.profile_by_stage,
|
||||
recv_req.profile_id,
|
||||
recv_req.merge_profiles,
|
||||
)
|
||||
return self.start_profile()
|
||||
else:
|
||||
|
||||
@@ -306,6 +306,7 @@ class TokenizerCommunicatorMixin:
|
||||
with_stack: Optional[bool] = None,
|
||||
record_shapes: Optional[bool] = None,
|
||||
profile_by_stage: bool = False,
|
||||
merge_profiles: bool = False,
|
||||
):
|
||||
self.auto_create_handle_loop()
|
||||
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
|
||||
@@ -320,6 +321,7 @@ class TokenizerCommunicatorMixin:
|
||||
record_shapes=record_shapes,
|
||||
profile_by_stage=profile_by_stage,
|
||||
profile_id=str(time.time()),
|
||||
merge_profiles=merge_profiles,
|
||||
)
|
||||
return await self._execute_profile(req)
|
||||
|
||||
|
||||
199
python/sglang/srt/utils/profile_merger.py
Normal file
199
python/sglang/srt/utils/profile_merger.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Merge Chrome trace files from multiple ranks (TP, DP, PP, EP) into a single trace."""
|
||||
|
||||
import glob
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProfileMerger:
    """Merge profile traces from all parallelism types: TP, DP, PP, EP.

    Per-rank Chrome trace files (``<profile_id>*-TP-<r>*.trace.json.gz``)
    found in ``output_dir`` are combined into a single
    ``merged-<profile_id>.trace.json.gz`` so every rank can be inspected on
    one timeline.  Each event's ``pid`` is prefixed with a rank label such
    as ``[TP00-DP01]`` to keep rows from different ranks distinguishable.
    """

    def __init__(self, output_dir: str, profile_id: str):
        """
        Args:
            output_dir: Directory containing the per-rank trace files; the
                merged trace is written to the same directory.
            profile_id: Filename prefix shared by one profiling session.
        """
        self.output_dir = output_dir
        self.profile_id = profile_id
        self.merged_trace_path = os.path.join(
            output_dir, f"merged-{profile_id}.trace.json.gz"
        )

        # Rank types in priority order (used for sorting and labeling)
        self.rank_types = ["tp", "dp", "pp", "ep"]

        # Sort index multipliers: DP (highest) > EP > PP > TP (lowest)
        # These ensure proper visual ordering in trace viewer
        self.sort_index_multipliers = {
            "dp_rank": 100_000_000,
            "ep_rank": 1_000_000,
            "pp_rank": 10_000,
            "tp_rank": 100,
        }

        # Only rewrite sort_index for small "system" PIDs; larger values are
        # real OS process ids whose ordering is left untouched.
        self.pid_sort_index_threshold = 1000

    def merge_chrome_traces(self) -> str:
        """Merge Chrome traces from all ranks into a single trace.

        Returns:
            Path to merged trace file.

        Raises:
            ValueError: If no trace files found.
        """
        trace_files = self._discover_trace_files()
        if not trace_files:
            raise ValueError(f"No trace files found for profile_id: {self.profile_id}")

        logger.info(f"Found {len(trace_files)} trace files to merge")

        merged_trace: Dict[str, Any] = {"traceEvents": []}
        all_device_properties: List[Any] = []

        # Deterministic order: lowest DP/EP/PP/TP ranks come first.
        for trace_file in sorted(trace_files, key=self._get_rank_sort_key):
            rank_info = self._extract_rank_info(trace_file)
            logger.info(f"Processing {trace_file} with rank info: {rank_info}")

            output = self._handle_file(trace_file, rank_info)

            merged_trace["traceEvents"].extend(output["traceEvents"])

            # deviceProperties are concatenated across ranks, unlike the
            # remaining metadata keys below which are first-file-wins.
            if "deviceProperties" in output:
                all_device_properties.extend(output["deviceProperties"])
                del output["deviceProperties"]

            for key, value in output.items():
                if key != "traceEvents" and key not in merged_trace:
                    merged_trace[key] = value

        if all_device_properties:
            merged_trace["deviceProperties"] = all_device_properties

        with gzip.open(self.merged_trace_path, "wb") as f:
            f.write(json.dumps(merged_trace).encode("utf-8"))

        logger.info(f"Merged profile saved to: {self.merged_trace_path}")
        logger.info(f"Total events merged: {len(merged_trace['traceEvents'])}")

        return self.merged_trace_path

    def _discover_trace_files(self) -> List[str]:
        """Discover trace files matching profile_id (supports TP/DP/PP/EP formats)."""
        patterns = [f"{self.profile_id}*.trace.json.gz"]

        trace_files = []
        for pattern in patterns:
            search_pattern = os.path.join(self.output_dir, pattern)
            trace_files.extend(glob.glob(search_pattern))

        # Exclude our own merged output and non-trace artifacts; require the
        # TP marker so only scheduler-emitted per-rank traces are merged.
        trace_files = [
            f
            for f in trace_files
            if not f.endswith(f"merged-{self.profile_id}.trace.json.gz")
            and not f.endswith("-memory.pickle")
            and "TP-" in f
        ]
        trace_files = list(set(trace_files))
        return trace_files

    def _extract_rank_info(self, filename: str) -> Dict[str, int]:
        """Extract rank info (TP/DP/PP/EP) from filename.

        Missing dimensions are simply absent from the returned dict.
        """
        basename = os.path.basename(filename)
        rank_info: Dict[str, int] = {}

        for rank_type in self.rank_types:
            match = re.search(rf"{rank_type.upper()}-(\d+)", basename)
            if match:
                rank_info[f"{rank_type}_rank"] = int(match.group(1))

        return rank_info

    def _create_rank_label(self, rank_info: Dict[str, int]) -> str:
        """Build a display label like ``[TP00-DP01]`` from extracted ranks."""
        parts = []
        for rank_type in self.rank_types:
            rank_key = f"{rank_type}_rank"
            if rank_key in rank_info:
                parts.append(f"{rank_type.upper()}{rank_info[rank_key]:02d}")

        return f"[{'-'.join(parts)}]" if parts else "[Unknown]"

    def _handle_file(self, path: str, rank_info: Dict[str, int]) -> Dict[str, Any]:
        """Load one gzipped trace and return it with events rank-labeled.

        A file that cannot be read or parsed is skipped (empty event list)
        rather than aborting the whole merge.
        """
        logger.info(f"Processing file: {path}")

        try:
            with gzip.open(path, "rt", encoding="utf-8") as f:
                trace = json.load(f)

            output = {
                key: value for key, value in trace.items() if key != "traceEvents"
            }
            output["traceEvents"] = self._process_events(
                trace.get("traceEvents", []), rank_info
            )
            return output

        except Exception as e:
            logger.error(f"Failed to process trace file {path}: {e}")
            return {"traceEvents": []}

    def _process_events(
        self, events: List[Dict], rank_info: Dict[str, int]
    ) -> List[Dict]:
        """Process events: update sort_index and add rank labels to PIDs."""
        rank_label = self._create_rank_label(rank_info)

        for event in events:
            if event.get("name") == "process_sort_index":
                pid = self._maybe_cast_int(event.get("pid"))
                if pid is not None and pid < self.pid_sort_index_threshold:
                    # setdefault guards against metadata events that carry no
                    # "args" dict (plain indexing would raise KeyError here).
                    event.setdefault("args", {})["sort_index"] = (
                        self._calculate_sort_index(rank_info, pid)
                    )

            # Prefix every pid with the rank label so rows from different
            # ranks do not collide in the merged timeline.
            event["pid"] = f"{rank_label} {event['pid']}"

        return events

    def _calculate_sort_index(self, rank_info: Dict[str, int], pid: int) -> int:
        """Combine pid with weighted ranks so the viewer groups DP>EP>PP>TP."""
        sort_index = pid
        for rank_type, multiplier in self.sort_index_multipliers.items():
            sort_index += rank_info.get(rank_type, 0) * multiplier
        return sort_index

    def _get_rank_sort_key(self, path: str) -> Tuple[int, int, int, int]:
        """Sort key (dp, ep, pp, tp) for deterministic merge ordering."""
        rank_info = self._extract_rank_info(path)
        return tuple(
            rank_info.get(f"{rank_type}_rank", 0)
            for rank_type in ["dp", "ep", "pp", "tp"]
        )

    def _maybe_cast_int(self, x) -> Optional[int]:
        """Return int(x), or None when x is not int-convertible."""
        try:
            return int(x)
        except (ValueError, TypeError):
            return None

    def get_merge_summary(self) -> Dict[str, Any]:
        """Summarize the merged trace (event/file counts, sources).

        Returns an ``{"error": ...}`` dict instead of raising so callers can
        embed the summary in a best-effort status message.
        """
        if not os.path.exists(self.merged_trace_path):
            return {"error": "Merged trace file not found"}

        try:
            with gzip.open(self.merged_trace_path, "rt") as f:
                merged_data = json.load(f)

            trace_files = self._discover_trace_files()

            return {
                "merged_file": self.merged_trace_path,
                "total_events": len(merged_data.get("traceEvents", [])),
                "total_files": len(trace_files),
                "source_files": [os.path.basename(f) for f in trace_files],
                "profile_id": self.profile_id,
                "device_properties_count": len(merged_data.get("deviceProperties", [])),
            }
        except Exception as e:
            return {"error": f"Failed to read merged trace: {str(e)}"}
|
||||
Reference in New Issue
Block a user