[FEATURE] Add Profile Trace Merger for Distributed Traces (#11413)

This commit is contained in:
Neelabh Sinha
2025-10-13 18:20:17 -07:00
committed by GitHub
parent 932e263725
commit aaf7af1b17
10 changed files with 849 additions and 11 deletions

View File

@@ -115,6 +115,8 @@ suites = {
TestFile("test_srt_engine.py", 261),
TestFile("test_standalone_speculative_decoding.py", 250),
TestFile("test_start_profile.py", 60),
TestFile("test_profile_merger.py", 60),
TestFile("test_profile_merger_http_api.py", 15),
TestFile("test_swa_unittest.py", 1),
TestFile("test_torch_compile.py", 76),
TestFile("test_torch_compile_moe.py", 172),

View File

@@ -0,0 +1,363 @@
"""
Unit tests for the ProfileMerger implementation.
Usage:
python test_profile_merger.py
python -m unittest test_profile_merger.py -v
"""
import gzip
import json
import os
import shutil
import tempfile
import unittest
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqInput, ProfileReqType
from sglang.srt.utils.profile_merger import ProfileMerger
class TestProfileMerger(unittest.TestCase):
    """Unit tests for ProfileMerger's rank parsing, file discovery, and merging.

    Trace files follow the naming scheme
    ``<profile_id>[-TP-i][-DP-j][-PP-k][-EP-l].trace.json.gz``; these tests
    exercise the private helpers directly as well as the end-to-end merge.
    """

    def setUp(self):
        # Fresh temp directory and merger instance per test for isolation.
        self.temp_dir = tempfile.mkdtemp()
        self.profile_id = "test_profile_123"
        self.merger = ProfileMerger(self.temp_dir, self.profile_id)

    def tearDown(self):
        # ignore_errors: directory may already be partially removed by a test.
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_rank_extraction_and_labeling(self):
        """_extract_rank_info parses ranks from filenames; _create_rank_label formats them."""
        # Test TP-only
        filename = f"{self.profile_id}-TP-0.trace.json.gz"
        rank_info = self.merger._extract_rank_info(filename)
        self.assertEqual(rank_info, {"tp_rank": 0})
        label = self.merger._create_rank_label(rank_info)
        self.assertEqual(label, "[TP00]")
        # Test all parallelism types
        filename = f"{self.profile_id}-TP-1-DP-2-PP-3-EP-4.trace.json.gz"
        rank_info = self.merger._extract_rank_info(filename)
        self.assertEqual(
            rank_info, {"tp_rank": 1, "dp_rank": 2, "pp_rank": 3, "ep_rank": 4}
        )
        label = self.merger._create_rank_label(rank_info)
        self.assertEqual(label, "[TP01-DP02-PP03-EP04]")
        # Test partial ranks
        filename = f"{self.profile_id}-TP-0-DP-1.trace.json.gz"
        rank_info = self.merger._extract_rank_info(filename)
        self.assertEqual(rank_info, {"tp_rank": 0, "dp_rank": 1})
        label = self.merger._create_rank_label(rank_info)
        self.assertEqual(label, "[TP00-DP01]")
        # Test no ranks
        filename = f"{self.profile_id}.trace.json.gz"
        rank_info = self.merger._extract_rank_info(filename)
        self.assertEqual(rank_info, {})
        label = self.merger._create_rank_label(rank_info)
        self.assertEqual(label, "[Unknown]")

    def test_sort_index_calculation(self):
        """Single/empty rank info keeps the original pid-based index; a full
        multi-rank dict remaps it to a large distinct value (> 1_000_000)."""
        # Single rank
        rank_info = {"tp_rank": 0}
        sort_idx = self.merger._calculate_sort_index(rank_info, 83)
        self.assertEqual(sort_idx, 83)
        # Multiple ranks
        rank_info = {"tp_rank": 1, "dp_rank": 2, "pp_rank": 3, "ep_rank": 4}
        sort_idx = self.merger._calculate_sort_index(rank_info, 83)
        self.assertNotEqual(sort_idx, 83)
        self.assertGreater(sort_idx, 1000000)
        # Empty ranks
        rank_info = {}
        sort_idx = self.merger._calculate_sort_index(rank_info, 83)
        self.assertEqual(sort_idx, 83)

    def test_rank_sort_key(self):
        """_get_rank_sort_key orders files by (DP, EP, PP, TP); missing ranks become 0."""
        # Full ranks: TP-1, DP-2, PP-3, EP-4 → sorted as (DP, EP, PP, TP)
        filename = f"{self.profile_id}-TP-1-DP-2-PP-3-EP-4.trace.json.gz"
        sort_key = self.merger._get_rank_sort_key(filename)
        self.assertEqual(sort_key, (2, 4, 3, 1))
        # Missing ranks: only TP-1 → sorted as (DP=0, EP=0, PP=0, TP=1)
        filename = f"{self.profile_id}-TP-1.trace.json.gz"
        sort_key = self.merger._get_rank_sort_key(filename)
        self.assertEqual(sort_key, (0, 0, 0, 1))

    def test_discover_trace_files(self):
        """_discover_trace_files finds every trace for this profile_id and nothing else."""
        # Create mock trace files
        trace_files = [
            f"{self.profile_id}-TP-0.trace.json.gz",  # Old format
            f"{self.profile_id}-TP-1.trace.json.gz",  # Old format
            f"{self.profile_id}-TP-0-DP-1.trace.json.gz",  # New format
        ]
        for filename in trace_files:
            filepath = os.path.join(self.temp_dir, filename)
            with gzip.open(filepath, "wt") as f:
                json.dump({"traceEvents": []}, f)
        discovered = self.merger._discover_trace_files()
        self.assertEqual(len(discovered), 3)
        # Check that all expected files are discovered
        discovered_basenames = {os.path.basename(f) for f in discovered}
        expected_basenames = {
            f"{self.profile_id}-TP-0.trace.json.gz",
            f"{self.profile_id}-TP-1.trace.json.gz",
            f"{self.profile_id}-TP-0-DP-1.trace.json.gz",
        }
        self.assertEqual(discovered_basenames, expected_basenames)
        # Test no matches
        empty_merger = ProfileMerger(self.temp_dir, "nonexistent")
        discovered = empty_merger._discover_trace_files()
        self.assertEqual(len(discovered), 0)

    def test_merge_chrome_traces(self):
        """End-to-end merge: files processed in rank order, events relabeled with
        rank-prefixed pids, summary populated, and ValueError when nothing matches."""
        # Create multiple trace files in random order
        trace_files = [
            {
                "filename": f"{self.profile_id}-TP-1-DP-1.trace.json.gz",
                "events": [
                    {"ph": "X", "name": "op1", "pid": 83, "ts": 1000.0, "dur": 10.0}
                ],
            },
            {
                "filename": f"{self.profile_id}-TP-0.trace.json.gz",
                "events": [
                    {"ph": "X", "name": "op2", "pid": 84, "ts": 2000.0, "dur": 15.0}
                ],
            },
            {
                "filename": f"{self.profile_id}-TP-0-DP-1.trace.json.gz",
                "events": [
                    {"ph": "X", "name": "op3", "pid": 85, "ts": 3000.0, "dur": 20.0}
                ],
            },
        ]
        for trace_data in trace_files:
            filepath = os.path.join(self.temp_dir, trace_data["filename"])
            trace_content = {
                "schemaVersion": 1,
                "deviceProperties": [{"device_id": 0, "name": "GPU-0"}],
                "traceEvents": trace_data["events"],
            }
            with gzip.open(filepath, "wt") as f:
                json.dump(trace_content, f)
        # Test file ordering by capturing log messages
        import logging

        # NOTE: the logger name must match the module path of profile_merger,
        # otherwise assertLogs captures nothing.
        logger = logging.getLogger("sglang.srt.utils.profile_merger")
        with self.assertLogs(logger, level="INFO") as log_capture:
            merged_path = self.merger.merge_chrome_traces()
        # Verify files were processed in rank order
        log_messages = [
            record.getMessage()
            for record in log_capture.records
            if "Processing file:" in record.getMessage()
        ]
        self.assertIn("TP-0.trace.json.gz", log_messages[0])  # (0,0,0,0) comes first
        self.assertIn(
            "TP-0-DP-1.trace.json.gz", log_messages[1]
        )  # (0,1,0,0) comes second
        self.assertIn(
            "TP-1-DP-1.trace.json.gz", log_messages[2]
        )  # (1,1,0,0) comes last
        # Verify merged content
        self.assertTrue(os.path.exists(merged_path))
        with gzip.open(merged_path, "rt") as f:
            merged_data = json.load(f)
        self.assertEqual(len(merged_data["traceEvents"]), 3)
        self.assertEqual(len(merged_data["deviceProperties"]), 3)
        # Check rank labels in events
        events = merged_data["traceEvents"]
        pids = [event["pid"] for event in events]
        self.assertIn("[TP00] 84", pids)
        self.assertIn("[TP00-DP01] 85", pids)
        self.assertIn("[TP01-DP01] 83", pids)
        # Test merge summary
        summary = self.merger.get_merge_summary()
        self.assertEqual(summary["total_files"], 3)
        self.assertEqual(summary["total_events"], 3)
        self.assertEqual(summary["profile_id"], self.profile_id)
        # Test no files error
        empty_merger = ProfileMerger(self.temp_dir, "nonexistent")
        with self.assertRaises(ValueError):
            empty_merger.merge_chrome_traces()
class TestProfileMergerIntegration(unittest.TestCase):
    """Checks that merge_profiles is wired through request structs and entry points."""

    def test_data_structures_merge_profiles(self):
        """merge_profiles defaults to False and can be enabled on both structs."""
        default_input = ProfileReqInput()
        self.assertFalse(default_input.merge_profiles)
        enabled_input = ProfileReqInput(merge_profiles=True)
        self.assertTrue(enabled_input.merge_profiles)

        default_req = ProfileReq(type=ProfileReqType.START_PROFILE)
        self.assertFalse(default_req.merge_profiles)
        enabled_req = ProfileReq(type=ProfileReqType.START_PROFILE, merge_profiles=True)
        self.assertTrue(enabled_req.merge_profiles)

    def test_integration_parameters(self):
        """Every integration entry point must expose a merge_profiles parameter."""
        import inspect

        from sglang.srt.managers.tokenizer_communicator_mixin import (
            TokenizerCommunicatorMixin,
        )
        from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
        from sglang.profiler import run_profile

        # TokenizerManager mixin, scheduler mixin, and the CLI profiler.
        entry_points = (
            TokenizerCommunicatorMixin.start_profile,
            SchedulerProfilerMixin.init_profile,
            run_profile,
        )
        for entry_point in entry_points:
            self.assertIn(
                "merge_profiles", inspect.signature(entry_point).parameters
            )
class TestProfileMergerEdgeCases(unittest.TestCase):
    """Edge-case tests for ProfileMerger: malformed files, missing fields,
    None handling in the helpers, and mixed-rank filename scenarios."""

    def setUp(self):
        # Fresh temp directory and merger instance per test for isolation.
        self.temp_dir = tempfile.mkdtemp()
        self.profile_id = "test_edge_cases"
        self.merger = ProfileMerger(self.temp_dir, self.profile_id)

    def tearDown(self):
        # Uses the module-level shutil import (the previous redundant local
        # `import shutil` here was removed for consistency with
        # TestProfileMerger.tearDown).
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_error_handling_and_edge_cases(self):
        """Malformed JSON, empty traces, and missing deviceProperties must not crash the merge."""
        # Test malformed trace file: merge still produces an (empty-event) output.
        filename = f"{self.profile_id}-TP-0.trace.json.gz"
        filepath = os.path.join(self.temp_dir, filename)
        with gzip.open(filepath, "wt") as f:
            f.write("invalid json content")
        merged_path = self.merger.merge_chrome_traces()
        self.assertTrue(os.path.exists(merged_path))
        with gzip.open(merged_path, "rt") as f:
            merged_data = json.load(f)
        self.assertEqual(len(merged_data["traceEvents"]), 0)
        # Test empty trace file
        with gzip.open(filepath, "wt") as f:
            json.dump({}, f)
        merged_path = self.merger.merge_chrome_traces()
        self.assertTrue(os.path.exists(merged_path))
        # Test missing device properties: key is omitted from the merged output.
        trace_data = {
            "schemaVersion": 1,
            "traceEvents": [
                {"ph": "X", "name": "test", "pid": 83, "ts": 1000.0, "dur": 10.0}
            ],
        }
        with gzip.open(filepath, "wt") as f:
            json.dump(trace_data, f)
        merged_path = self.merger.merge_chrome_traces()
        with gzip.open(merged_path, "rt") as f:
            merged_data = json.load(f)
        self.assertNotIn("deviceProperties", merged_data)

    def test_missing_ranks_and_none_handling(self):
        """Helpers must tolerate missing ranks and non-integer inputs."""
        # Test rank extraction with missing ranks
        filename = f"{self.profile_id}-TP-0.trace.json.gz"
        rank_info = self.merger._extract_rank_info(filename)
        self.assertEqual(rank_info, {"tp_rank": 0})
        # Test rank label creation with missing ranks
        label = self.merger._create_rank_label({"tp_rank": 0})
        self.assertEqual(label, "[TP00]")
        label = self.merger._create_rank_label({})
        self.assertEqual(label, "[Unknown]")
        # Test sort index calculation
        sort_idx = self.merger._calculate_sort_index({"tp_rank": 0}, 83)
        self.assertGreater(sort_idx, 0)
        sort_idx = self.merger._calculate_sort_index({}, 83)
        self.assertEqual(sort_idx, 83)
        # Test sort key generation
        sort_key = self.merger._get_rank_sort_key(filename)
        self.assertEqual(sort_key, (0, 0, 0, 0))
        # Test _maybe_cast_int with various inputs: None and non-numeric
        # strings yield None; numeric strings and ints come back as int.
        self.assertIsNone(self.merger._maybe_cast_int(None))
        self.assertIsNone(self.merger._maybe_cast_int("invalid"))
        self.assertEqual(self.merger._maybe_cast_int("123"), 123)
        self.assertEqual(self.merger._maybe_cast_int(456), 456)

    def test_mixed_rank_scenarios(self):
        """Merging files whose names carry different rank combinations labels each correctly."""
        trace_scenarios = [
            {
                "filename": f"{self.profile_id}-TP-0.trace.json.gz",
                "events": [
                    {"ph": "X", "name": "op1", "pid": 83, "ts": 1000.0, "dur": 10.0}
                ],
            },
            {
                "filename": f"{self.profile_id}-TP-1-DP-0.trace.json.gz",
                "events": [
                    {"ph": "X", "name": "op2", "pid": 84, "ts": 2000.0, "dur": 15.0}
                ],
            },
            {
                "filename": f"{self.profile_id}-TP-0-DP-1-PP-0.trace.json.gz",
                "events": [
                    {"ph": "X", "name": "op3", "pid": 85, "ts": 3000.0, "dur": 20.0}
                ],
            },
        ]
        for scenario in trace_scenarios:
            filepath = os.path.join(self.temp_dir, scenario["filename"])
            trace_data = {
                "schemaVersion": 1,
                "deviceProperties": [{"device_id": 0, "name": "GPU-0"}],
                "traceEvents": scenario["events"],
            }
            with gzip.open(filepath, "wt") as f:
                json.dump(trace_data, f)
        merged_path = self.merger.merge_chrome_traces()
        self.assertTrue(os.path.exists(merged_path))
        with gzip.open(merged_path, "rt") as f:
            merged_data = json.load(f)
        self.assertEqual(len(merged_data["traceEvents"]), 3)
        # Each event's pid must be prefixed with its file's rank label.
        events = merged_data["traceEvents"]
        pids = [event["pid"] for event in events]
        self.assertIn("[TP00] 83", pids)
        self.assertIn("[TP01-DP00] 84", pids)
        self.assertIn("[TP00-DP01-PP00] 85", pids)
if __name__ == "__main__":
    # Support direct invocation: python test_profile_merger.py
    unittest.main()

View File

@@ -0,0 +1,162 @@
import json
import unittest
from sglang.srt.managers.io_struct import ProfileReqInput
class TestProfileMergerHTTPAPI(unittest.TestCase):
    """HTTP-API-level tests: merge_profiles survives the JSON round-trip into
    ProfileReqInput exactly as an HTTP server would construct it."""

    def test_profile_req_input_merge_profiles_json_serialization(self):
        """A ProfileReqInput with merge_profiles=True serializes to JSON intact."""
        request = ProfileReqInput(
            output_dir="/tmp/test",
            num_steps=5,
            activities=["CPU", "GPU"],
            profile_by_stage=True,
            merge_profiles=True,
        )
        # Convert to dict (as would happen in HTTP request).
        payload = {
            name: getattr(request, name)
            for name in (
                "output_dir",
                "num_steps",
                "activities",
                "profile_by_stage",
                "merge_profiles",
            )
        }
        # Round-trip through JSON to mimic the wire format.
        decoded = json.loads(json.dumps(payload))
        self.assertTrue(decoded["merge_profiles"])
        self.assertEqual(decoded["output_dir"], "/tmp/test")
        self.assertEqual(decoded["num_steps"], 5)
        self.assertEqual(decoded["activities"], ["CPU", "GPU"])
        self.assertTrue(decoded["profile_by_stage"])

    def test_profile_req_input_merge_profiles_json_deserialization(self):
        """A JSON request body constructs ProfileReqInput with merge_profiles honored."""
        payload = {
            "output_dir": "/tmp/test",
            "num_steps": 10,
            "activities": ["CPU", "GPU", "MEM"],
            "profile_by_stage": False,
            "merge_profiles": True,
        }
        # Create ProfileReqInput from dict (as the HTTP server would do).
        request = ProfileReqInput(**payload)
        self.assertTrue(request.merge_profiles)
        self.assertEqual(request.output_dir, "/tmp/test")
        self.assertEqual(request.num_steps, 10)
        self.assertEqual(request.activities, ["CPU", "GPU", "MEM"])
        self.assertFalse(request.profile_by_stage)

    def test_profile_req_input_merge_profiles_default_value(self):
        """merge_profiles defaults to False when the request omits it."""
        request = ProfileReqInput(**{"output_dir": "/tmp/test"})
        self.assertFalse(request.merge_profiles)

    def test_profile_req_input_merge_profiles_explicit_false(self):
        """An explicit merge_profiles=False is preserved."""
        request = ProfileReqInput(
            **{"output_dir": "/tmp/test", "merge_profiles": False}
        )
        self.assertFalse(request.merge_profiles)

    def test_http_api_parameter_flow(self):
        """All parameters flow from the raw request dict into the object."""
        incoming = {
            "output_dir": "/tmp/test",
            "num_steps": 5,
            "activities": ["CPU", "GPU"],
            "profile_by_stage": True,
            "merge_profiles": True,
        }
        parsed = ProfileReqInput(**incoming)
        self.assertTrue(parsed.merge_profiles)
        self.assertEqual(parsed.output_dir, "/tmp/test")
        self.assertEqual(parsed.num_steps, 5)
        self.assertEqual(parsed.activities, ["CPU", "GPU"])
        self.assertTrue(parsed.profile_by_stage)

    def test_http_api_parameter_validation(self):
        """Booleans pass through unchanged; strings are NOT coerced to bool."""
        self.assertTrue(ProfileReqInput(**{"merge_profiles": True}).merge_profiles)
        self.assertFalse(ProfileReqInput(**{"merge_profiles": False}).merge_profiles)
        # A string "true" (should be converted by the JSON parser upstream)
        # is stored verbatim — string, not boolean.
        self.assertEqual(
            ProfileReqInput(**{"merge_profiles": "true"}).merge_profiles, "true"
        )

    def test_http_api_backward_compatibility(self):
        """Older clients that never send merge_profiles get the False default."""
        self.assertFalse(ProfileReqInput(**{}).merge_profiles)
        legacy_payload = {
            "output_dir": "/tmp/test",
            "num_steps": 5,
            "activities": ["CPU", "GPU"],
        }
        self.assertFalse(ProfileReqInput(**legacy_payload).merge_profiles)

    def test_http_api_parameter_combinations(self):
        """merge_profiles is honored across minimal and fully-populated requests."""
        cases = [
            {
                "name": "minimal with merge_profiles",
                "data": {"merge_profiles": True},
                "expected_merge": True,
            },
            {
                "name": "full parameters with merge_profiles=True",
                "data": {
                    "output_dir": "/tmp/test",
                    "num_steps": 10,
                    "activities": ["CPU", "GPU", "MEM"],
                    "profile_by_stage": True,
                    "with_stack": True,
                    "record_shapes": True,
                    "merge_profiles": True,
                },
                "expected_merge": True,
            },
            {
                "name": "full parameters with merge_profiles=False",
                "data": {
                    "output_dir": "/tmp/test",
                    "num_steps": 10,
                    "activities": ["CPU", "GPU", "MEM"],
                    "profile_by_stage": False,
                    "with_stack": False,
                    "record_shapes": False,
                    "merge_profiles": False,
                },
                "expected_merge": False,
            },
        ]
        for case in cases:
            with self.subTest(case["name"]):
                parsed = ProfileReqInput(**case["data"])
                self.assertEqual(parsed.merge_profiles, case["expected_merge"])
if __name__ == "__main__":
    # Support direct invocation: python test_profile_merger_http_api.py
    unittest.main()