diff --git a/python/sglang/srt/debug_utils/dump_comparator.py b/python/sglang/srt/debug_utils/dump_comparator.py index 946cdc4fb..aca9c3b7a 100644 --- a/python/sglang/srt/debug_utils/dump_comparator.py +++ b/python/sglang/srt/debug_utils/dump_comparator.py @@ -1,11 +1,11 @@ import argparse import functools -import re from pathlib import Path import polars as pl import torch +from sglang.srt.debug_utils.dump_loader import find_row, read_meta from sglang.srt.debug_utils.dumper import get_truncated_value @@ -26,66 +26,77 @@ def main(args): print("df_baseline", df_baseline) for row in df_target.iter_rows(named=True): - rows_baseline = df_baseline.filter( - ( - pl.col("forward_pass_id") - == row["forward_pass_id"] - args.start_id + args.baseline_start_id - ) - & functools.reduce( - lambda a, b: a & b, - [ - pl.col(col) == row[col] - for col in row.keys() - if col not in ["forward_pass_id", "dump_index", "filename"] - ], - ) + path_target = Path(args.target_path) / row["filename"] + + row_baseline = find_row( + df_baseline, + conditions=dict( + forward_pass_id=row["forward_pass_id"] + - args.start_id + + args.baseline_start_id, + **{ + k: v + for k, v in row.items() + if k not in ["forward_pass_id", "dump_index", "filename"] + }, + ), ) - assert len(rows_baseline) == 1, f"{rows_baseline=}" - row_baseline = rows_baseline.to_dicts()[0] + + if row_baseline is None: + print(f"Skip: target={str(path_target)} since no baseline") + x_target = _load_object(path_target) + if x_target is not None: + print(f"x_target(sample)={get_truncated_value(x_target)}") + continue path_baseline = Path(args.baseline_path) / row_baseline["filename"] - path_target = Path(args.target_path) / row["filename"] print(f"Check: target={str(path_target)} baseline={str(path_baseline)}") - check_tensor_pair(path_baseline=path_baseline, path_target=path_target) + check_tensor_pair( + path_baseline=path_baseline, path_target=path_target, name=row["name"] + ) print() -def read_meta(directory): - directory = Path(directory) - assert directory.is_dir(), f"{directory=} should be a directory" - - rows = [] - for p in directory.glob("*.pt"): - full_kwargs = {} - for kv in p.stem.split("___"): - k, v = kv.split("=") - full_kwargs[k] = v - rows.append( - { - "filename": str(p.name), - **full_kwargs, - } - ) - - df = pl.DataFrame(rows) - df = df.with_columns( - pl.col("forward_pass_id").cast(int), - pl.col("rank").cast(int), - ) - return df - - -def check_tensor_pair(path_baseline, path_target): - x_baseline = torch.load(path_baseline, weights_only=True) - x_target = torch.load(path_target, weights_only=True) +def check_tensor_pair(path_baseline, path_target, name=""): + x_baseline = _load_object(path_baseline) + x_target = _load_object(path_target) print( + f"Raw " f"[shape] {x_baseline.shape} vs {x_target.shape}\t" f"[dtype] {x_baseline.dtype} vs {x_target.dtype}" ) + x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name) + x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape) + + print( + f"After preprocessor " + f"[shape] {x_baseline.shape} vs {x_target.shape}\t" + f"[dtype] {x_baseline.dtype} vs {x_target.dtype}" + ) + + x_target = x_target.float() + x_baseline = x_baseline.float() + + for name, fn in ( + ("mean", torch.mean), + ("std", torch.std), + ("min", torch.min), + ("max", torch.max), + ("p1", functools.partial(torch.quantile, q=0.01)), + ("p5", functools.partial(torch.quantile, q=0.05)), + ("p95", functools.partial(torch.quantile, q=0.95)), + ("p99", functools.partial(torch.quantile, q=0.99)), + ): + value_baseline = fn(x_baseline).item() + value_target = fn(x_target).item() + print( + f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})" + ) + if x_baseline.shape != x_target.shape: - print(f"❌ Shape mismatch") + print(f"⚠️ Shape mismatch") return raw_abs_diff = (x_target - x_baseline).abs() @@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target): print(f"x_target(sample)={get_truncated_value(x_target)}") +def _try_unify_shape(x: torch.Tensor, target_shape): + x_shape = x.shape + num_dim_to_remove = len(x_shape) - len(target_shape) + if (x_shape[num_dim_to_remove:] == target_shape) and all( + val == 1 for val in x_shape[:num_dim_to_remove] + ): + out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x) + print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})") + return out + + return x + + # Copied from DeepGEMM def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor): x, y = x.double(), y.double() @@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim +def _comparison_preprocessor(x_baseline, x_target, name): + # can insert arbitrary adhoc postprocessing logic here + return x_baseline, x_target + + +def _load_object(path): + x = torch.load(path, weights_only=False) + if not isinstance(x, torch.Tensor): + print(f"Skip load {path} since {type(x)=} is not a Tensor") + return None + return x.cuda() + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--baseline-path", type=str) diff --git a/python/sglang/srt/debug_utils/dump_loader.py b/python/sglang/srt/debug_utils/dump_loader.py new file mode 100644 index 000000000..8e6f2c79b --- /dev/null +++ b/python/sglang/srt/debug_utils/dump_loader.py @@ -0,0 +1,97 @@ +import functools +import os +from pathlib import Path +from typing import Any, Dict + +import polars as pl +import torch + + +class DumpLoader: + def __init__(self): + directory = os.environ.get("SGLANG_DUMP_LOADER_DIR") + + self._enable = directory is not None + if self._enable: + self._directory = Path(directory) + self._df = read_meta(directory) + + @property + def enable(self): + return self._enable + + def load(self, name, **kwargs): + assert self._enable, "Please call DumpLoader.load only when it is enabled" + + from sglang.srt.debug_utils.dumper import dumper + + forward_pass_id = dumper._forward_pass_id + conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs) + row = find_row(self._df, conditions=conditions) + assert ( + row is not None + ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}" + + path = self._directory / row["filename"] + output = torch.load(path, weights_only=False) + + print( + f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})" + ) + return output + + +def read_meta(directory): + directory = Path(directory) + assert directory.is_dir(), f"{directory=} should be a directory" + + rows = [] + for p in directory.glob("*.pt"): + full_kwargs = {} + for kv in p.stem.split("___"): + k, v = kv.split("=") + full_kwargs[k] = v + rows.append( + { + "filename": str(p.name), + **full_kwargs, + } + ) + + df = pl.DataFrame(rows) + df = df.with_columns( + pl.col("forward_pass_id").cast(int), + pl.col("rank").cast(int), + pl.col("dump_index").cast(int), + ) + return df + + +def find_row(df, conditions: Dict[str, Any]): + df_sub = df.filter( + functools.reduce( + lambda a, b: a & b, + [ + pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col]) + for col in conditions.keys() + ], + ) + ) + assert len(df_sub) <= 1 + return df_sub.to_dicts()[0] if len(df_sub) > 0 else None + + +def _cast_to_polars_dtype(value, target_dtype): + if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32): + return int(value) + elif target_dtype in (pl.Float64, pl.Float32): + return float(value) + elif target_dtype == pl.Boolean: + return bool(value) + elif target_dtype == pl.String: + return str(value) + else: + return value + + +dump_loader = DumpLoader() diff --git a/python/sglang/srt/debug_utils/dumper.py b/python/sglang/srt/debug_utils/dumper.py index d10301241..8a9808bb7 100644 --- a/python/sglang/srt/debug_utils/dumper.py +++ b/python/sglang/srt/debug_utils/dumper.py @@ -53,7 +53,7 @@ class _Dumper: if self._partial_name is None: self._partial_name = _get_partial_name() - rank = dist.get_rank() + rank = _get_rank() full_kwargs = dict( forward_pass_id=self._forward_pass_id, rank=rank, @@ -80,12 +80,20 @@ class _Dumper: def _get_partial_name(): - rank = dist.get_rank() + rank = _get_rank() object_list = [str(time.time()) if rank == 0 else None] - dist.broadcast_object_list(object_list, device="cuda") + if dist.is_initialized(): + dist.broadcast_object_list(object_list, device="cuda") return object_list[0] +def _get_rank(): + if dist.is_initialized(): + return dist.get_rank() + else: + return 0 + + def get_truncated_value(value): if value is None: return None