Fix and enhance dumper (#8725)
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
import argparse
|
||||
import functools
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
|
||||
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
|
||||
from sglang.srt.debug_utils.dumper import get_truncated_value
|
||||
|
||||
|
||||
@@ -26,66 +26,77 @@ def main(args):
|
||||
print("df_baseline", df_baseline)
|
||||
|
||||
for row in df_target.iter_rows(named=True):
|
||||
rows_baseline = df_baseline.filter(
|
||||
(
|
||||
pl.col("forward_pass_id")
|
||||
== row["forward_pass_id"] - args.start_id + args.baseline_start_id
|
||||
)
|
||||
& functools.reduce(
|
||||
lambda a, b: a & b,
|
||||
[
|
||||
pl.col(col) == row[col]
|
||||
for col in row.keys()
|
||||
if col not in ["forward_pass_id", "dump_index", "filename"]
|
||||
],
|
||||
)
|
||||
path_target = Path(args.target_path) / row["filename"]
|
||||
|
||||
row_baseline = find_row(
|
||||
df_baseline,
|
||||
conditions=dict(
|
||||
forward_pass_id=row["forward_pass_id"]
|
||||
- args.start_id
|
||||
+ args.baseline_start_id,
|
||||
**{
|
||||
k: v
|
||||
for k, v in row.items()
|
||||
if k not in ["forward_pass_id", "dump_index", "filename"]
|
||||
},
|
||||
),
|
||||
)
|
||||
assert len(rows_baseline) == 1, f"{rows_baseline=}"
|
||||
row_baseline = rows_baseline.to_dicts()[0]
|
||||
|
||||
if row_baseline is None:
|
||||
print(f"Skip: target={str(path_target)} since no baseline")
|
||||
x_target = _load_object(path_target)
|
||||
if x_target is not None:
|
||||
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
||||
continue
|
||||
|
||||
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
|
||||
path_target = Path(args.target_path) / row["filename"]
|
||||
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
|
||||
check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
|
||||
check_tensor_pair(
|
||||
path_baseline=path_baseline, path_target=path_target, name=row["name"]
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
def read_meta(directory):
|
||||
directory = Path(directory)
|
||||
assert directory.is_dir(), f"{directory=} should be a directory"
|
||||
|
||||
rows = []
|
||||
for p in directory.glob("*.pt"):
|
||||
full_kwargs = {}
|
||||
for kv in p.stem.split("___"):
|
||||
k, v = kv.split("=")
|
||||
full_kwargs[k] = v
|
||||
rows.append(
|
||||
{
|
||||
"filename": str(p.name),
|
||||
**full_kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
df = pl.DataFrame(rows)
|
||||
df = df.with_columns(
|
||||
pl.col("forward_pass_id").cast(int),
|
||||
pl.col("rank").cast(int),
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
def check_tensor_pair(path_baseline, path_target):
|
||||
x_baseline = torch.load(path_baseline, weights_only=True)
|
||||
x_target = torch.load(path_target, weights_only=True)
|
||||
def check_tensor_pair(path_baseline, path_target, name=""):
|
||||
x_baseline = _load_object(path_baseline)
|
||||
x_target = _load_object(path_target)
|
||||
|
||||
print(
|
||||
f"Raw "
|
||||
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
||||
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
||||
)
|
||||
|
||||
x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
|
||||
x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
|
||||
|
||||
print(
|
||||
f"After preprocessor "
|
||||
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
||||
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
||||
)
|
||||
|
||||
x_target = x_target.float()
|
||||
x_baseline = x_baseline.float()
|
||||
|
||||
for name, fn in (
|
||||
("mean", torch.mean),
|
||||
("std", torch.std),
|
||||
("min", torch.min),
|
||||
("max", torch.max),
|
||||
("p1", functools.partial(torch.quantile, q=0.01)),
|
||||
("p5", functools.partial(torch.quantile, q=0.05)),
|
||||
("p95", functools.partial(torch.quantile, q=0.95)),
|
||||
("p99", functools.partial(torch.quantile, q=0.99)),
|
||||
):
|
||||
value_baseline = fn(x_baseline).item()
|
||||
value_target = fn(x_target).item()
|
||||
print(
|
||||
f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
|
||||
)
|
||||
|
||||
if x_baseline.shape != x_target.shape:
|
||||
print(f"❌ Shape mismatch")
|
||||
print(f"⚠️ Shape mismatch")
|
||||
return
|
||||
|
||||
raw_abs_diff = (x_target - x_baseline).abs()
|
||||
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
|
||||
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
||||
|
||||
|
||||
def _try_unify_shape(x: torch.Tensor, target_shape):
|
||||
x_shape = x.shape
|
||||
num_dim_to_remove = len(x_shape) - len(target_shape)
|
||||
if (x_shape[num_dim_to_remove:] == target_shape) and all(
|
||||
val == 1 for val in x_shape[:num_dim_to_remove]
|
||||
):
|
||||
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
|
||||
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
|
||||
return out
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Copied from DeepGEMM
|
||||
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double(), y.double()
|
||||
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def _comparison_preprocessor(x_baseline, x_target, name):
|
||||
# can insert arbitrary adhoc postprocessing logic here
|
||||
return x_baseline, x_target
|
||||
|
||||
|
||||
def _load_object(path):
|
||||
x = torch.load(path, weights_only=False)
|
||||
if not isinstance(x, torch.Tensor):
|
||||
print(f"Skip load {path} since {type(x)=} is not a Tensor")
|
||||
return None
|
||||
return x.cuda()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--baseline-path", type=str)
|
||||
|
||||
97
python/sglang/srt/debug_utils/dump_loader.py
Normal file
97
python/sglang/srt/debug_utils/dump_loader.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import functools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
|
||||
|
||||
class DumpLoader:
|
||||
def __init__(self):
|
||||
directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
|
||||
|
||||
self._enable = directory is not None
|
||||
if self._enable:
|
||||
self._directory = Path(directory)
|
||||
self._df = read_meta(directory)
|
||||
|
||||
@property
|
||||
def enable(self):
|
||||
return self._enable
|
||||
|
||||
def load(self, name, **kwargs):
|
||||
assert self._enable, "Please call DumpLoader.load only when it is enabled"
|
||||
|
||||
from sglang.srt.debug_utils.dumper import dumper
|
||||
|
||||
forward_pass_id = dumper._forward_pass_id
|
||||
conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
|
||||
row = find_row(self._df, conditions=conditions)
|
||||
assert (
|
||||
row is not None
|
||||
), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"
|
||||
|
||||
path = self._directory / row["filename"]
|
||||
output = torch.load(path, weights_only=False)
|
||||
|
||||
print(
|
||||
f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def read_meta(directory):
|
||||
directory = Path(directory)
|
||||
assert directory.is_dir(), f"{directory=} should be a directory"
|
||||
|
||||
rows = []
|
||||
for p in directory.glob("*.pt"):
|
||||
full_kwargs = {}
|
||||
for kv in p.stem.split("___"):
|
||||
k, v = kv.split("=")
|
||||
full_kwargs[k] = v
|
||||
rows.append(
|
||||
{
|
||||
"filename": str(p.name),
|
||||
**full_kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
df = pl.DataFrame(rows)
|
||||
df = df.with_columns(
|
||||
pl.col("forward_pass_id").cast(int),
|
||||
pl.col("rank").cast(int),
|
||||
pl.col("dump_index").cast(int),
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
def find_row(df, conditions: Dict[str, Any]):
|
||||
df_sub = df.filter(
|
||||
functools.reduce(
|
||||
lambda a, b: a & b,
|
||||
[
|
||||
pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
|
||||
for col in conditions.keys()
|
||||
],
|
||||
)
|
||||
)
|
||||
assert len(df_sub) <= 1
|
||||
return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
|
||||
|
||||
|
||||
def _cast_to_polars_dtype(value, target_dtype):
|
||||
if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
|
||||
return int(value)
|
||||
elif target_dtype in (pl.Float64, pl.Float32):
|
||||
return float(value)
|
||||
elif target_dtype == pl.Boolean:
|
||||
return bool(value)
|
||||
elif target_dtype == pl.String:
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
dump_loader = DumpLoader()
|
||||
@@ -53,7 +53,7 @@ class _Dumper:
|
||||
if self._partial_name is None:
|
||||
self._partial_name = _get_partial_name()
|
||||
|
||||
rank = dist.get_rank()
|
||||
rank = _get_rank()
|
||||
full_kwargs = dict(
|
||||
forward_pass_id=self._forward_pass_id,
|
||||
rank=rank,
|
||||
@@ -80,12 +80,20 @@ class _Dumper:
|
||||
|
||||
|
||||
def _get_partial_name():
|
||||
rank = dist.get_rank()
|
||||
rank = _get_rank()
|
||||
object_list = [str(time.time()) if rank == 0 else None]
|
||||
dist.broadcast_object_list(object_list, device="cuda")
|
||||
if dist.is_initialized():
|
||||
dist.broadcast_object_list(object_list, device="cuda")
|
||||
return object_list[0]
|
||||
|
||||
|
||||
def _get_rank():
|
||||
if dist.is_initialized():
|
||||
return dist.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_truncated_value(value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user