debugging

This commit is contained in:
Chranos
2026-02-10 16:10:28 +08:00
parent 893eeb2208
commit de8fc97532
2 changed files with 99 additions and 5 deletions

View File

@@ -492,6 +492,29 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
return module
def _get_device_memory_info() -> Tuple[Optional[float], Optional[float], Optional[float]]:
"""Get device memory info in GiB. Returns (allocated, reserved, total) or Nones."""
try:
import torch.mlu
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
free, total = torch.mlu.mem_get_info()
total = total / (1024 ** 3)
return allocated, reserved, total
except Exception:
pass
try:
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
free, total = torch.cuda.mem_get_info()
total = total / (1024 ** 3)
return allocated, reserved, total
except Exception:
pass
return None, None, None
def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
@@ -505,11 +528,31 @@ def make_layers(
start_layer, end_layer = get_pp_indices(num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
alloc_before, _, total = _get_device_memory_info()
if alloc_before is not None:
logger.info(
"[DEBUG-MEM] make_layers start: allocated=%.2f GiB, "
"total=%.2f GiB, layers to create: %d-%d / %d",
alloc_before, total, start_layer, end_layer, num_hidden_layers)
created_layers = []
for idx in range(start_layer, end_layer):
layer = maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
alloc_after, reserved, _ = _get_device_memory_info()
if alloc_after is not None:
delta = alloc_after - alloc_before
logger.info(
"[DEBUG-MEM] Layer %s.%d created: "
"allocated=%.2f GiB (+%.4f GiB), reserved=%.2f GiB",
prefix, idx, alloc_after, delta, reserved)
alloc_before = alloc_after
created_layers.append(layer)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
[PPMissingLayer() for _ in range(start_layer)]
+ created_layers
+ [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules