Tool to dump and compare internal activation tensors (#7976)
This commit is contained in:
@@ -1,74 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
|
||||
class _Dumper:
|
||||
"""Utility to dump tensors, which can be useful when comparison checking models.
|
||||
|
||||
Example usage:
|
||||
debug_utils.dumper.dump("layer_start_hidden_states", hidden_states, layer_id=self.layer_id)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._enable = get_bool_env_var("SGLANG_DUMPER_ENABLE", "true")
|
||||
self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
|
||||
self._enable_write_file = get_bool_env_var("SGLANG_DUMPER_WRITE_FILE", "1")
|
||||
self._partial_name = str(time.time())
|
||||
self.forward_pass_id = None
|
||||
|
||||
def dump(self, name, value, **kwargs):
|
||||
if not self._enable:
|
||||
return
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
full_kwargs = dict(
|
||||
forward_pass_id=self.forward_pass_id,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
|
||||
path = (
|
||||
self._base_dir / f"sglang_dump_{self._partial_name}_{rank}" / full_filename
|
||||
)
|
||||
|
||||
sample_value = self._get_sample_value(name, value)
|
||||
|
||||
print(
|
||||
f"[{rank}, {time.time()}] {path} "
|
||||
f"type={type(value)} "
|
||||
f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
|
||||
f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
|
||||
f"sample_value={sample_value}"
|
||||
)
|
||||
|
||||
if self._enable_write_file:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(value, str(path))
|
||||
|
||||
def _get_sample_value(self, name, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, tuple):
|
||||
return [self._get_sample_value(name, x) for x in value]
|
||||
|
||||
if not isinstance(value, torch.Tensor):
|
||||
return None
|
||||
|
||||
if value.numel() < 200:
|
||||
return value
|
||||
|
||||
slices = [
|
||||
slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
|
||||
]
|
||||
return value[tuple(slices)]
|
||||
|
||||
|
||||
# Module-level singleton; importers share this one instance's state.
dumper = _Dumper()
|
||||
0
python/sglang/srt/debug_utils/__init__.py
Normal file
0
python/sglang/srt/debug_utils/__init__.py
Normal file
131
python/sglang/srt/debug_utils/dump_comparator.py
Normal file
131
python/sglang/srt/debug_utils/dump_comparator.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import argparse
|
||||
import functools
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
|
||||
from sglang.srt.debug_utils.dumper import get_truncated_value
|
||||
|
||||
|
||||
def main(args):
    """Compare dumps under `args.target_path` against those under `args.baseline_path`."""
    df_target = read_meta(args.target_path)
    df_target = df_target.sort("rank", "dump_index")
    in_range = (pl.col("forward_pass_id") >= args.start_id) & (
        pl.col("forward_pass_id") <= args.end_id
    )
    df_target = df_target.filter(in_range)
    required_cols = ["rank", "forward_pass_id", "dump_index", "name"]
    assert all(c in df_target.columns for c in required_cols)

    df_baseline = read_meta(args.baseline_path)
    print("df_target", df_target)
    print("df_baseline", df_baseline)

    for row in df_target.iter_rows(named=True):
        # Baseline forward-pass ids may be offset from the target's ids.
        baseline_pass_id = (
            row["forward_pass_id"] - args.start_id + args.baseline_start_id
        )
        # Match on every metadata column except the per-run ones.
        predicate = pl.col("forward_pass_id") == baseline_pass_id
        for col in row.keys():
            if col in ("forward_pass_id", "dump_index", "filename"):
                continue
            predicate = predicate & (pl.col(col) == row[col])
        rows_baseline = df_baseline.filter(predicate)
        assert len(rows_baseline) == 1, f"{rows_baseline=}"
        row_baseline = rows_baseline.to_dicts()[0]

        path_baseline = Path(args.baseline_path) / row_baseline["filename"]
        path_target = Path(args.target_path) / row["filename"]
        print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
        check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
        print()
|
||||
|
||||
|
||||
def read_meta(directory):
    """Scan `directory` for dumped `*.pt` files and return their metadata as a DataFrame.

    Filenames have the form ``k1=v1___k2=v2___....pt`` (see the dumper); each
    file becomes one row with a ``filename`` column plus one column per key.

    Args:
        directory: path to a dump directory produced by the dumper.

    Returns:
        A polars DataFrame with `forward_pass_id`, `rank` (and `dump_index`,
        when present) cast to integers.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            # Split on the first "=" only, so values that themselves contain
            # "=" do not raise a ValueError here.
            k, v = kv.split("=", 1)
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
    )
    if "dump_index" in df.columns:
        # Cast so sorting by dump_index is numeric, not lexicographic
        # (as strings, "10" would sort before "2").
        df = df.with_columns(pl.col("dump_index").cast(int))
    return df
|
||||
|
||||
|
||||
def check_tensor_pair(path_baseline, path_target):
|
||||
x_baseline = torch.load(path_baseline, weights_only=True)
|
||||
x_target = torch.load(path_target, weights_only=True)
|
||||
|
||||
print(
|
||||
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
||||
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
||||
)
|
||||
|
||||
if x_baseline.shape != x_target.shape:
|
||||
print(f"❌ Shape mismatch")
|
||||
return
|
||||
|
||||
raw_abs_diff = (x_target - x_baseline).abs()
|
||||
|
||||
max_abs_diff = raw_abs_diff.max().item()
|
||||
mean_abs_diff = raw_abs_diff.mean().item()
|
||||
rel_diff = _calc_rel_diff(x_target, x_baseline)
|
||||
|
||||
needs_print = max_abs_diff > 1e-3
|
||||
|
||||
print(
|
||||
"\t".join(
|
||||
f"{'❌' if value > 1e-3 else '✅'} {name}={value}"
|
||||
for name, value in [
|
||||
("rel_diff", rel_diff),
|
||||
("max_abs_diff", max_abs_diff),
|
||||
("mean_abs_diff", mean_abs_diff),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
if needs_print:
|
||||
print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
|
||||
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
||||
|
||||
|
||||
# Copied from DeepGEMM
|
||||
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: compare dumps in --target-path against --baseline-path, restricting
    # target forward passes to [--start-id, --end-id] and mapping them onto
    # baseline passes starting at --baseline-start-id.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--baseline-path", type=str)
    arg_parser.add_argument("--target-path", type=str)
    arg_parser.add_argument("--start-id", type=int, default=0)
    arg_parser.add_argument("--end-id", type=int, default=1000000)
    arg_parser.add_argument("--baseline-start-id", type=int, default=0)
    main(arg_parser.parse_args())
|
||||
108
python/sglang/srt/debug_utils/dumper.py
Normal file
108
python/sglang/srt/debug_utils/dumper.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class _Dumper:
|
||||
"""Utility to dump tensors, which can be useful when comparison checking models.
|
||||
|
||||
Example usage:
|
||||
dumper.on_forward_pass_start()
|
||||
dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)
|
||||
|
||||
Import from non-SGLang system:
|
||||
```
|
||||
import sys
|
||||
sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
|
||||
from dumper import dumper
|
||||
```
|
||||
|
||||
Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Do not import `sglang` to make this file standalone
|
||||
self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1")))
|
||||
self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
|
||||
self._enable_write_file = bool(
|
||||
int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1"))
|
||||
)
|
||||
self._partial_name: Optional[str] = None
|
||||
self._dump_index = 0
|
||||
self._forward_pass_id = 0
|
||||
|
||||
def on_forward_pass_start(self):
|
||||
self._forward_pass_id += 1
|
||||
print(
|
||||
f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
|
||||
)
|
||||
|
||||
def dump(self, name, value, **kwargs):
|
||||
if not self._enable:
|
||||
return
|
||||
|
||||
assert (
|
||||
self._forward_pass_id >= 1
|
||||
), "Do you forget to call `dumper.on_forward_pass_start()`?"
|
||||
self._dump_index += 1
|
||||
|
||||
if self._partial_name is None:
|
||||
self._partial_name = _get_partial_name()
|
||||
|
||||
rank = dist.get_rank()
|
||||
full_kwargs = dict(
|
||||
forward_pass_id=self._forward_pass_id,
|
||||
rank=rank,
|
||||
name=name,
|
||||
dump_index=self._dump_index,
|
||||
**kwargs,
|
||||
)
|
||||
full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
|
||||
path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename
|
||||
|
||||
sample_value = get_truncated_value(value)
|
||||
|
||||
print(
|
||||
f"[Dumper] [{rank}, {time.time()}] {path} "
|
||||
f"type={type(value)} "
|
||||
f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
|
||||
f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
|
||||
f"sample_value={sample_value}"
|
||||
)
|
||||
|
||||
if self._enable_write_file:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(value, str(path))
|
||||
|
||||
|
||||
def _get_partial_name():
    """Pick a run-unique, timestamp-based name on rank 0 and broadcast it so
    every rank writes into the same dump directory name."""
    if dist.get_rank() == 0:
        holder = [str(time.time())]
    else:
        holder = [None]
    # NOTE(review): broadcasts over the default process group and assumes a
    # CUDA-capable backend (device="cuda") — confirm for CPU-only runs.
    dist.broadcast_object_list(holder, device="cuda")
    return holder[0]
|
||||
|
||||
|
||||
def get_truncated_value(value):
    """Return a small printable preview of `value`.

    Tensors larger than 200 elements are truncated; tuples are handled
    element-wise (returned as a list); anything else yields None.
    """
    if value is None:
        return None

    if isinstance(value, tuple):
        return [get_truncated_value(item) for item in value]

    if not isinstance(value, torch.Tensor):
        return None

    # Small tensors are returned as-is.
    if value.numel() < 200:
        return value

    # Keep only the first 5 entries along each dimension larger than 200.
    keep = tuple(
        slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
    )
    return value[keep]
|
||||
|
||||
|
||||
# Module-level singleton; importers share this one instance's counters/config.
dumper = _Dumper()
|
||||
Reference in New Issue
Block a user