Tiny improve dumper (#11132)
This commit is contained in:
@@ -36,6 +36,15 @@ class _Dumper:
|
|||||||
self._forward_pass_id = 0
|
self._forward_pass_id = 0
|
||||||
|
|
||||||
def on_forward_pass_start(self):
|
def on_forward_pass_start(self):
|
||||||
|
"""This should be called on all ranks."""
|
||||||
|
|
||||||
|
if not self._enable:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Users may want to `dump` only on some ranks, thus determine name here
|
||||||
|
if self._partial_name is None:
|
||||||
|
self._partial_name = _get_partial_name()
|
||||||
|
|
||||||
self._forward_pass_id += 1
|
self._forward_pass_id += 1
|
||||||
print(
|
print(
|
||||||
f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
|
f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
|
||||||
@@ -48,11 +57,9 @@ class _Dumper:
|
|||||||
assert (
|
assert (
|
||||||
self._forward_pass_id >= 1
|
self._forward_pass_id >= 1
|
||||||
), "Do you forget to call `dumper.on_forward_pass_start()`?"
|
), "Do you forget to call `dumper.on_forward_pass_start()`?"
|
||||||
|
assert self._partial_name is not None
|
||||||
self._dump_index += 1
|
self._dump_index += 1
|
||||||
|
|
||||||
if self._partial_name is None:
|
|
||||||
self._partial_name = _get_partial_name()
|
|
||||||
|
|
||||||
rank = _get_rank()
|
rank = _get_rank()
|
||||||
full_kwargs = dict(
|
full_kwargs = dict(
|
||||||
forward_pass_id=self._forward_pass_id,
|
forward_pass_id=self._forward_pass_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user