Fix and enhance dumper (#8725)
This commit is contained in:
@@ -1,11 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import functools
|
import functools
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
|
||||||
from sglang.srt.debug_utils.dumper import get_truncated_value
|
from sglang.srt.debug_utils.dumper import get_truncated_value
|
||||||
|
|
||||||
|
|
||||||
@@ -26,66 +26,77 @@ def main(args):
|
|||||||
print("df_baseline", df_baseline)
|
print("df_baseline", df_baseline)
|
||||||
|
|
||||||
for row in df_target.iter_rows(named=True):
|
for row in df_target.iter_rows(named=True):
|
||||||
rows_baseline = df_baseline.filter(
|
path_target = Path(args.target_path) / row["filename"]
|
||||||
(
|
|
||||||
pl.col("forward_pass_id")
|
row_baseline = find_row(
|
||||||
== row["forward_pass_id"] - args.start_id + args.baseline_start_id
|
df_baseline,
|
||||||
)
|
conditions=dict(
|
||||||
& functools.reduce(
|
forward_pass_id=row["forward_pass_id"]
|
||||||
lambda a, b: a & b,
|
- args.start_id
|
||||||
[
|
+ args.baseline_start_id,
|
||||||
pl.col(col) == row[col]
|
**{
|
||||||
for col in row.keys()
|
k: v
|
||||||
if col not in ["forward_pass_id", "dump_index", "filename"]
|
for k, v in row.items()
|
||||||
],
|
if k not in ["forward_pass_id", "dump_index", "filename"]
|
||||||
)
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert len(rows_baseline) == 1, f"{rows_baseline=}"
|
|
||||||
row_baseline = rows_baseline.to_dicts()[0]
|
if row_baseline is None:
|
||||||
|
print(f"Skip: target={str(path_target)} since no baseline")
|
||||||
|
x_target = _load_object(path_target)
|
||||||
|
if x_target is not None:
|
||||||
|
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
||||||
|
continue
|
||||||
|
|
||||||
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
|
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
|
||||||
path_target = Path(args.target_path) / row["filename"]
|
|
||||||
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
|
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
|
||||||
check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
|
check_tensor_pair(
|
||||||
|
path_baseline=path_baseline, path_target=path_target, name=row["name"]
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
def read_meta(directory):
|
def check_tensor_pair(path_baseline, path_target, name=""):
|
||||||
directory = Path(directory)
|
x_baseline = _load_object(path_baseline)
|
||||||
assert directory.is_dir(), f"{directory=} should be a directory"
|
x_target = _load_object(path_target)
|
||||||
|
|
||||||
rows = []
|
|
||||||
for p in directory.glob("*.pt"):
|
|
||||||
full_kwargs = {}
|
|
||||||
for kv in p.stem.split("___"):
|
|
||||||
k, v = kv.split("=")
|
|
||||||
full_kwargs[k] = v
|
|
||||||
rows.append(
|
|
||||||
{
|
|
||||||
"filename": str(p.name),
|
|
||||||
**full_kwargs,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
df = pl.DataFrame(rows)
|
|
||||||
df = df.with_columns(
|
|
||||||
pl.col("forward_pass_id").cast(int),
|
|
||||||
pl.col("rank").cast(int),
|
|
||||||
)
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def check_tensor_pair(path_baseline, path_target):
|
|
||||||
x_baseline = torch.load(path_baseline, weights_only=True)
|
|
||||||
x_target = torch.load(path_target, weights_only=True)
|
|
||||||
|
|
||||||
print(
|
print(
|
||||||
|
f"Raw "
|
||||||
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
||||||
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
|
||||||
|
x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"After preprocessor "
|
||||||
|
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
||||||
|
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
|
x_target = x_target.float()
|
||||||
|
x_baseline = x_baseline.float()
|
||||||
|
|
||||||
|
for name, fn in (
|
||||||
|
("mean", torch.mean),
|
||||||
|
("std", torch.std),
|
||||||
|
("min", torch.min),
|
||||||
|
("max", torch.max),
|
||||||
|
("p1", functools.partial(torch.quantile, q=0.01)),
|
||||||
|
("p5", functools.partial(torch.quantile, q=0.05)),
|
||||||
|
("p95", functools.partial(torch.quantile, q=0.95)),
|
||||||
|
("p99", functools.partial(torch.quantile, q=0.99)),
|
||||||
|
):
|
||||||
|
value_baseline = fn(x_baseline).item()
|
||||||
|
value_target = fn(x_target).item()
|
||||||
|
print(
|
||||||
|
f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
|
||||||
|
)
|
||||||
|
|
||||||
if x_baseline.shape != x_target.shape:
|
if x_baseline.shape != x_target.shape:
|
||||||
print(f"❌ Shape mismatch")
|
print(f"⚠️ Shape mismatch")
|
||||||
return
|
return
|
||||||
|
|
||||||
raw_abs_diff = (x_target - x_baseline).abs()
|
raw_abs_diff = (x_target - x_baseline).abs()
|
||||||
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
|
|||||||
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _try_unify_shape(x: torch.Tensor, target_shape):
|
||||||
|
x_shape = x.shape
|
||||||
|
num_dim_to_remove = len(x_shape) - len(target_shape)
|
||||||
|
if (x_shape[num_dim_to_remove:] == target_shape) and all(
|
||||||
|
val == 1 for val in x_shape[:num_dim_to_remove]
|
||||||
|
):
|
||||||
|
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
|
||||||
|
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
|
||||||
|
return out
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
# Copied from DeepGEMM
|
# Copied from DeepGEMM
|
||||||
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
||||||
x, y = x.double(), y.double()
|
x, y = x.double(), y.double()
|
||||||
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
|||||||
return 1 - sim
|
return 1 - sim
|
||||||
|
|
||||||
|
|
||||||
|
def _comparison_preprocessor(x_baseline, x_target, name):
|
||||||
|
# can insert arbitrary adhoc postprocessing logic here
|
||||||
|
return x_baseline, x_target
|
||||||
|
|
||||||
|
|
||||||
|
def _load_object(path):
|
||||||
|
x = torch.load(path, weights_only=False)
|
||||||
|
if not isinstance(x, torch.Tensor):
|
||||||
|
print(f"Skip load {path} since {type(x)=} is not a Tensor")
|
||||||
|
return None
|
||||||
|
return x.cuda()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--baseline-path", type=str)
|
parser.add_argument("--baseline-path", type=str)
|
||||||
|
|||||||
97
python/sglang/srt/debug_utils/dump_loader.py
Normal file
97
python/sglang/srt/debug_utils/dump_loader.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import functools
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class DumpLoader:
|
||||||
|
def __init__(self):
|
||||||
|
directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
|
||||||
|
|
||||||
|
self._enable = directory is not None
|
||||||
|
if self._enable:
|
||||||
|
self._directory = Path(directory)
|
||||||
|
self._df = read_meta(directory)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enable(self):
|
||||||
|
return self._enable
|
||||||
|
|
||||||
|
def load(self, name, **kwargs):
|
||||||
|
assert self._enable, "Please call DumpLoader.load only when it is enabled"
|
||||||
|
|
||||||
|
from sglang.srt.debug_utils.dumper import dumper
|
||||||
|
|
||||||
|
forward_pass_id = dumper._forward_pass_id
|
||||||
|
conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
|
||||||
|
row = find_row(self._df, conditions=conditions)
|
||||||
|
assert (
|
||||||
|
row is not None
|
||||||
|
), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"
|
||||||
|
|
||||||
|
path = self._directory / row["filename"]
|
||||||
|
output = torch.load(path, weights_only=False)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def read_meta(directory):
|
||||||
|
directory = Path(directory)
|
||||||
|
assert directory.is_dir(), f"{directory=} should be a directory"
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for p in directory.glob("*.pt"):
|
||||||
|
full_kwargs = {}
|
||||||
|
for kv in p.stem.split("___"):
|
||||||
|
k, v = kv.split("=")
|
||||||
|
full_kwargs[k] = v
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"filename": str(p.name),
|
||||||
|
**full_kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
df = pl.DataFrame(rows)
|
||||||
|
df = df.with_columns(
|
||||||
|
pl.col("forward_pass_id").cast(int),
|
||||||
|
pl.col("rank").cast(int),
|
||||||
|
pl.col("dump_index").cast(int),
|
||||||
|
)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def find_row(df, conditions: Dict[str, Any]):
|
||||||
|
df_sub = df.filter(
|
||||||
|
functools.reduce(
|
||||||
|
lambda a, b: a & b,
|
||||||
|
[
|
||||||
|
pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
|
||||||
|
for col in conditions.keys()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(df_sub) <= 1
|
||||||
|
return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def _cast_to_polars_dtype(value, target_dtype):
|
||||||
|
if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
|
||||||
|
return int(value)
|
||||||
|
elif target_dtype in (pl.Float64, pl.Float32):
|
||||||
|
return float(value)
|
||||||
|
elif target_dtype == pl.Boolean:
|
||||||
|
return bool(value)
|
||||||
|
elif target_dtype == pl.String:
|
||||||
|
return str(value)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
dump_loader = DumpLoader()
|
||||||
@@ -53,7 +53,7 @@ class _Dumper:
|
|||||||
if self._partial_name is None:
|
if self._partial_name is None:
|
||||||
self._partial_name = _get_partial_name()
|
self._partial_name = _get_partial_name()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = _get_rank()
|
||||||
full_kwargs = dict(
|
full_kwargs = dict(
|
||||||
forward_pass_id=self._forward_pass_id,
|
forward_pass_id=self._forward_pass_id,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
@@ -80,12 +80,20 @@ class _Dumper:
|
|||||||
|
|
||||||
|
|
||||||
def _get_partial_name():
|
def _get_partial_name():
|
||||||
rank = dist.get_rank()
|
rank = _get_rank()
|
||||||
object_list = [str(time.time()) if rank == 0 else None]
|
object_list = [str(time.time()) if rank == 0 else None]
|
||||||
dist.broadcast_object_list(object_list, device="cuda")
|
if dist.is_initialized():
|
||||||
|
dist.broadcast_object_list(object_list, device="cuda")
|
||||||
return object_list[0]
|
return object_list[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rank():
|
||||||
|
if dist.is_initialized():
|
||||||
|
return dist.get_rank()
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def get_truncated_value(value):
|
def get_truncated_value(value):
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user