diff --git a/python/sglang/srt/debug_utils.py b/python/sglang/srt/debug_utils.py deleted file mode 100644 index f019971df..000000000 --- a/python/sglang/srt/debug_utils.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import time -from pathlib import Path - -import torch - -from sglang.srt.utils import get_bool_env_var - - -class _Dumper: - """Utility to dump tensors, which can be useful when comparison checking models. - - Example usage: - debug_utils.dumper.dump("layer_start_hidden_states", hidden_states, layer_id=self.layer_id) - """ - - def __init__(self): - self._enable = get_bool_env_var("SGLANG_DUMPER_ENABLE", "true") - self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp")) - self._enable_write_file = get_bool_env_var("SGLANG_DUMPER_WRITE_FILE", "1") - self._partial_name = str(time.time()) - self.forward_pass_id = None - - def dump(self, name, value, **kwargs): - if not self._enable: - return - - from sglang.srt.distributed import get_tensor_model_parallel_rank - - rank = get_tensor_model_parallel_rank() - full_kwargs = dict( - forward_pass_id=self.forward_pass_id, - name=name, - **kwargs, - ) - full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt" - path = ( - self._base_dir / f"sglang_dump_{self._partial_name}_{rank}" / full_filename - ) - - sample_value = self._get_sample_value(name, value) - - print( - f"[{rank}, {time.time()}] {path} " - f"type={type(value)} " - f"shape={value.shape if isinstance(value, torch.Tensor) else None} " - f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} " - f"sample_value={sample_value}" - ) - - if self._enable_write_file: - path.parent.mkdir(parents=True, exist_ok=True) - torch.save(value, str(path)) - - def _get_sample_value(self, name, value): - if value is None: - return None - - if isinstance(value, tuple): - return [self._get_sample_value(name, x) for x in value] - - if not isinstance(value, torch.Tensor): - return None - - if value.numel() < 200: - return value - - slices = [ - slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape - ] - return value[tuple(slices)] - - -dumper = _Dumper() diff --git a/python/sglang/srt/debug_utils/__init__.py b/python/sglang/srt/debug_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sglang/srt/debug_utils/dump_comparator.py b/python/sglang/srt/debug_utils/dump_comparator.py new file mode 100644 index 000000000..946cdc4fb --- /dev/null +++ b/python/sglang/srt/debug_utils/dump_comparator.py @@ -0,0 +1,131 @@ +import argparse +import functools +import re +from pathlib import Path + +import polars as pl +import torch + +from sglang.srt.debug_utils.dumper import get_truncated_value + + +def main(args): + df_target = read_meta(args.target_path) + df_target = df_target.sort("rank", "dump_index") + df_target = df_target.filter( + (pl.col("forward_pass_id") >= args.start_id) + & (pl.col("forward_pass_id") <= args.end_id) + ) + assert all( + c in df_target.columns + for c in ["rank", "forward_pass_id", "dump_index", "name"] + ) + + df_baseline = read_meta(args.baseline_path) + print("df_target", df_target) + print("df_baseline", df_baseline) + + for row in df_target.iter_rows(named=True): + rows_baseline = df_baseline.filter( + ( + pl.col("forward_pass_id") + == row["forward_pass_id"] - args.start_id + args.baseline_start_id + ) + & functools.reduce( + lambda a, b: a & b, + [ + pl.col(col) == row[col] + for col in row.keys() + if col not in ["forward_pass_id", "dump_index", "filename"] + ], + ) + ) + assert len(rows_baseline) == 1, f"{rows_baseline=}" + row_baseline = rows_baseline.to_dicts()[0] + + path_baseline = Path(args.baseline_path) / row_baseline["filename"] + path_target = Path(args.target_path) / row["filename"] + print(f"Check: target={str(path_target)} baseline={str(path_baseline)}") + check_tensor_pair(path_baseline=path_baseline, path_target=path_target) + print() + + +def read_meta(directory): + directory = Path(directory) + assert directory.is_dir(), f"{directory=} should be a directory" + + rows = [] + for p in directory.glob("*.pt"): + full_kwargs = {} + for kv in p.stem.split("___"): + k, v = kv.split("=") + full_kwargs[k] = v + rows.append( + { + "filename": str(p.name), + **full_kwargs, + } + ) + + df = pl.DataFrame(rows) + df = df.with_columns( + pl.col("forward_pass_id").cast(int), + pl.col("rank").cast(int), + ) + return df + + +def check_tensor_pair(path_baseline, path_target): + x_baseline = torch.load(path_baseline, weights_only=True) + x_target = torch.load(path_target, weights_only=True) + + print( + f"[shape] {x_baseline.shape} vs {x_target.shape}\t" + f"[dtype] {x_baseline.dtype} vs {x_target.dtype}" + ) + + if x_baseline.shape != x_target.shape: + print(f"❌ Shape mismatch") + return + + raw_abs_diff = (x_target - x_baseline).abs() + + max_abs_diff = raw_abs_diff.max().item() + mean_abs_diff = raw_abs_diff.mean().item() + rel_diff = _calc_rel_diff(x_target, x_baseline) + + needs_print = max_abs_diff > 1e-3 + + print( + "\t".join( + f"{'❌' if value > 1e-3 else '✅'} {name}={value}" + for name, value in [ + ("rel_diff", rel_diff), + ("max_abs_diff", max_abs_diff), + ("mean_abs_diff", mean_abs_diff), + ] + ) + ) + + if needs_print: + print(f"x_baseline(sample)={get_truncated_value(x_baseline)}") + print(f"x_target(sample)={get_truncated_value(x_target)}") + + +# Copied from DeepGEMM +def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--baseline-path", type=str) + parser.add_argument("--target-path", type=str) + parser.add_argument("--start-id", type=int, default=0) + parser.add_argument("--end-id", type=int, default=1000000) + parser.add_argument("--baseline-start-id", type=int, default=0) + args = parser.parse_args() + main(args) diff --git a/python/sglang/srt/debug_utils/dumper.py b/python/sglang/srt/debug_utils/dumper.py new file mode 100644 index 000000000..d10301241 --- /dev/null +++ b/python/sglang/srt/debug_utils/dumper.py @@ -0,0 +1,108 @@ +import os +import time +from pathlib import Path +from typing import Optional + +import torch +import torch.distributed as dist + + +class _Dumper: + """Utility to dump tensors, which can be useful when comparison checking models. + + Example usage: + dumper.on_forward_pass_start() + dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id) + + Import from non-SGLang system: + ``` + import sys + sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils") + from dumper import dumper + ``` + + Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison + """ + + def __init__(self): + # Do not import `sglang` to make this file standalone + self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1"))) + self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp")) + self._enable_write_file = bool( + int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1")) + ) + self._partial_name: Optional[str] = None + self._dump_index = 0 + self._forward_pass_id = 0 + + def on_forward_pass_start(self): + self._forward_pass_id += 1 + print( + f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}" + ) + + def dump(self, name, value, **kwargs): + if not self._enable: + return + + assert ( + self._forward_pass_id >= 1 + ), "Do you forget to call `dumper.on_forward_pass_start()`?" + self._dump_index += 1 + + if self._partial_name is None: + self._partial_name = _get_partial_name() + + rank = dist.get_rank() + full_kwargs = dict( + forward_pass_id=self._forward_pass_id, + rank=rank, + name=name, + dump_index=self._dump_index, + **kwargs, + ) + full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt" + path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename + + sample_value = get_truncated_value(value) + + print( + f"[Dumper] [{rank}, {time.time()}] {path} " + f"type={type(value)} " + f"shape={value.shape if isinstance(value, torch.Tensor) else None} " + f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} " + f"sample_value={sample_value}" + ) + + if self._enable_write_file: + path.parent.mkdir(parents=True, exist_ok=True) + torch.save(value, str(path)) + + +def _get_partial_name(): + rank = dist.get_rank() + object_list = [str(time.time()) if rank == 0 else None] + dist.broadcast_object_list(object_list, device="cuda") + return object_list[0] + + +def get_truncated_value(value): + if value is None: + return None + + if isinstance(value, tuple): + return [get_truncated_value(x) for x in value] + + if not isinstance(value, torch.Tensor): + return None + + if value.numel() < 200: + return value + + slices = [ + slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape + ] + return value[tuple(slices)] + + +dumper = _Dumper()