Files
2025-08-05 19:02:46 +08:00

95 lines
2.8 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional, Sequence, Tuple
import torch.nn as nn
from .profiler import (
MemSnapshotsProfiler,
NsightProfiler,
PyTorchProfiler,
PyTorchProfiler_CUDAOnly,
_Profiler,
)
from .slow_ops_profiler import DetectSlowOpsProfiler # noqa: F401
DEFAULT_SCHEDULE = (
(MemSnapshotsProfiler, 0, 2),
(NsightProfiler, 4, 6),
(PyTorchProfiler, 6, 7),
(PyTorchProfiler_CUDAOnly, 7, 8),
# TODO: There are some issues in PyTorch stable
# which are now fixed on main, but might break this profiler
# https://github.com/pytorch/pytorch/issues/94403
# (DetectSlowOpsProfiler, 9, 10),
)
def profile(
output_dir: str,
module: Optional[nn.Module] = None,
schedule: Sequence[Tuple[Any, int, int]] = DEFAULT_SCHEDULE,
):
"""
A pre-configured profiler that will run on the first ~20 steps of the training
It will provide multiple traces that can be exploited later.
Use it in a context manager around your training loop, and call `xformers.profiler.step`
before starting the next iteration.
:Examples:
.. code-block:: python
import torch
import timm.models
import xformers.profiler
dtype = torch.bfloat16
device = "cuda"
model = timm.models.vit_large_patch16_224().to(device).to(dtype)
inp = torch.zeros([64, 3, 224, 224], device=device, dtype=dtype)
optim = torch.optim.Adam(model.parameters())
with xformers.profiler.profile(
output_dir="profile_data",
module=model,
schedule=[
(MemSnapshotsProfiler, 0, 2),
(DetectSlowOpsProfiler, 2, 4),
(NsightProfiler, 4, 6),
(PyTorchProfiler, 6, 20),
]
):
for i in range(20):
model(inp).sum().backward()
optim.step()
optim.zero_grad()
xformers.profiler.step()
# alternatively, use the profiler without context and with ``.start()`` / `.stop()`
# calls.
xprofiler = xformers.profiler.profile(...)
xprofiler.start()
for i in range(20):
model(inp).sum().backward()
optim.step()
optim.zero_grad()
xprofiler.step()
xprofiler.stop()
"""
return _Profiler(output_dir=output_dir, schedule=schedule, module=module)
def step() -> None:
"""See `xformers.profiler.profile`"""
# Silently return if no profiler is enabled
if _Profiler._CURRENT_PROFILER is None:
return
_Profiler._CURRENT_PROFILER.step()