Minor add utility to read expert distribution recorder output (#7134)
This commit is contained in:
1
python/sglang/srt/eplb_simulator/__init__.py
Normal file
1
python/sglang/srt/eplb_simulator/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from . import reader
|
||||
51
python/sglang/srt/eplb_simulator/reader.py
Normal file
51
python/sglang/srt/eplb_simulator/reader.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
_convert_global_physical_count_to_logical_count,
|
||||
)
|
||||
|
||||
# Public alias: re-export the private helper from `expert_distribution` under
# a name without the leading underscore so users of this package can call it.
convert_global_physical_count_to_logical_count = _convert_global_physical_count_to_logical_count
|
||||
|
||||
|
||||
def read_mode_per_pass(dir_data: Path):
    """Read data from ExpertDistributionRecorder when recorded with mode `per_pass`.

    Every ``*.pt`` file directly inside ``dir_data`` is loaded with
    ``torch.load(..., weights_only=True)``. Each file must provide a
    ``"records"`` iterable whose entries carry ``"forward_pass_id"``,
    ``"rank"`` and ``"global_physical_count"``, plus a
    ``"last_physical_to_logical_map"`` entry. For each forward pass the
    counts of all ranks are summed, and the per-pass sums are stacked in
    ascending ``forward_pass_id`` order.

    Args:
        dir_data: Directory holding the recorder dumps. Anything accepted by
            ``pathlib.Path`` (e.g. a plain string) works.

    Returns:
        A dict with three entries:

        * ``global_physical_count_of_forward_pass``: stacked tensor whose
          first dimension follows ``forward_pass_ids``.
        * ``last_physical_to_logical_map``: the map stored in the file that
          sorts last by path.
          NOTE(review): presumably all files carry the same map -- confirm.
        * ``forward_pass_ids``: sorted list of the forward pass ids found.

    Raises:
        RuntimeError: If ``dir_data`` contains no ``*.pt`` file, or the files
            contain no records.
        AssertionError: If the same ``(forward_pass_id, rank)`` pair is
            recorded more than once.
    """
    # gpc := global_physical_count

    # Sort the paths: glob order depends on the filesystem, and the map of
    # the last file visited is what gets returned.
    paths = sorted(Path(dir_data).glob("*.pt"))
    if not paths:
        raise RuntimeError(f"No *.pt files found in {dir_data}")

    gpc_of_forward_pass_and_rank = defaultdict(dict)
    last_physical_to_logical_map = None
    for path in tqdm(paths):
        data_pack = torch.load(path, weights_only=True)
        last_physical_to_logical_map = data_pack["last_physical_to_logical_map"]
        for record in data_pack["records"]:
            forward_pass_id = record["forward_pass_id"]
            rank = record["rank"]
            gpc_of_rank = gpc_of_forward_pass_and_rank[forward_pass_id]
            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; the exception type is kept for callers.
            if rank in gpc_of_rank:
                raise AssertionError(f"Duplicated {forward_pass_id=} {rank=}")
            gpc_of_rank[rank] = record["global_physical_count"]

    if not gpc_of_forward_pass_and_rank:
        raise RuntimeError(f"No records found in the *.pt files of {dir_data}")

    forward_pass_ids = sorted(gpc_of_forward_pass_and_rank)
    print(f"Make {forward_pass_ids=} into array")

    items = []
    for forward_pass_id in forward_pass_ids:
        gpc_of_rank = gpc_of_forward_pass_and_rank[forward_pass_id]
        # Sum over ranks; ranks are visited in ascending order so the
        # reduction order is reproducible.
        gpc_of_rank_tensor = torch.stack(
            [gpc_of_rank[rank] for rank in sorted(gpc_of_rank)]
        ).sum(dim=0)
        items.append(gpc_of_rank_tensor)

    gpc_of_forward_pass = torch.stack(items)
    print(f"{gpc_of_forward_pass.shape=}")

    return dict(
        global_physical_count_of_forward_pass=gpc_of_forward_pass,
        last_physical_to_logical_map=last_physical_to_logical_map,
        forward_pass_ids=forward_pass_ids,
    )
|
||||
Reference in New Issue
Block a user