First commit
94
pkgs/xformers/profiler/api.py
Normal file
@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional, Sequence, Tuple

import torch.nn as nn

from .profiler import (
    MemSnapshotsProfiler,
    NsightProfiler,
    PyTorchProfiler,
    PyTorchProfiler_CUDAOnly,
    _Profiler,
)
from .slow_ops_profiler import DetectSlowOpsProfiler  # noqa: F401

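# Each schedule entry is a (profiler_class, start_step, stop_step) tuple
# (matching `Sequence[Tuple[Any, int, int]]` in `profile` below). The step
# counter advances on each call to `xformers.profiler.step`, and a profiler
# is only active between its start and stop steps.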
DEFAULT_SCHEDULE = (
    (MemSnapshotsProfiler, 0, 2),
    (NsightProfiler, 4, 6),
    (PyTorchProfiler, 6, 7),
    (PyTorchProfiler_CUDAOnly, 7, 8),
    # TODO: There are some issues in PyTorch stable
    # which are now fixed on main, but might break this profiler
    # https://github.com/pytorch/pytorch/issues/94403
    # (DetectSlowOpsProfiler, 9, 10),
)


def profile(
    output_dir: str,
    module: Optional[nn.Module] = None,
    schedule: Sequence[Tuple[Any, int, int]] = DEFAULT_SCHEDULE,
):
    """
    A pre-configured profiler that will run on the first ~20 steps of the training.
    It will provide multiple traces that can be analyzed later.
    Use it as a context manager around your training loop, and call `xformers.profiler.step`
    before starting the next iteration.

    :Examples:

    .. code-block:: python

        import torch
        import timm.models
        import xformers.profiler
        from xformers.profiler.profiler import (
            MemSnapshotsProfiler,
            NsightProfiler,
            PyTorchProfiler,
        )
        from xformers.profiler.slow_ops_profiler import DetectSlowOpsProfiler

        dtype = torch.bfloat16
        device = "cuda"
        model = timm.models.vit_large_patch16_224().to(device).to(dtype)
        inp = torch.zeros([64, 3, 224, 224], device=device, dtype=dtype)
        optim = torch.optim.Adam(model.parameters())

        with xformers.profiler.profile(
            output_dir="profile_data",
            module=model,
            schedule=[
                (MemSnapshotsProfiler, 0, 2),
                (DetectSlowOpsProfiler, 2, 4),
                (NsightProfiler, 4, 6),
                (PyTorchProfiler, 6, 20),
            ],
        ):
            for i in range(20):
                model(inp).sum().backward()
                optim.step()
                optim.zero_grad()
                xformers.profiler.step()

        # Alternatively, use the profiler without a context manager,
        # with explicit ``.start()`` / ``.stop()`` calls.
        xprofiler = xformers.profiler.profile(...)
        xprofiler.start()

        for i in range(20):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()
            xprofiler.step()

        xprofiler.stop()
    """
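    # The returned _Profiler can be used as a context manager, or started and
    # stopped explicitly with `.start()` / `.stop()` (see the examples above).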
    return _Profiler(output_dir=output_dir, schedule=schedule, module=module)


def step() -> None:
    """See `xformers.profiler.profile`"""
    # Silently return if no profiler is enabled
    if _Profiler._CURRENT_PROFILER is None:
        return
    _Profiler._CURRENT_PROFILER.step()